vision-agent 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vision_agent/agent/easytool_prompts.py CHANGED
@@ -56,6 +56,7 @@ Example 2: {{"Parameters":[{{"input": [1,2,3]}}, {{"input": [2,3,4]}}]}}
 
 These are logs of previous questions and answers:
 {previous_log}
+
 This is the current user's question: {question}
 This is the API tool documentation: {tool_usage}
 Output: """
@@ -67,15 +68,22 @@ Please note that:
 2. We will not show the API response to the user, thus you need to make full use of the response and give the information in the response that can satisfy the user's question in as much detail as possible.
 3. If the API tool does not provide useful information in the response, please answer with your knowledge.
 4. The question may have dependencies on answers of other questions, so we will provide logs of previous questions and answers.
+
 These are logs of previous questions and answers:
 {previous_log}
+
 This is the user's question: {question}
+
 This is the response output by the API tool:
 {call_results}
+
 We will not show the API response to the user, thus you need to make full use of the response and give the information in the response that can satisfy the user's question in as much detail as possible.
 Output: """
 
 ANSWER_SUMMARIZE = """We break down a complex user's problems into simple subtasks and provide answers to each simple subtask. You need to organize these answers to each subtask and form a self-consistent final answer to the user's question.
 This is the user's question: {question}
-These are subtasks and their answers: {answers}
+
+These are subtasks and their answers:
+{answers}
+
 Final answer: """
vision_agent/agent/reflexion.py CHANGED
@@ -238,12 +238,20 @@ class Reflexion(Agent):
                     self._build_agent_prompt(question, reflections, scratchpad)
                 )
             )
-        return format_step(
-            self.action_agent(
-                self._build_agent_prompt(question, reflections, scratchpad),
-                image=image,
+        elif isinstance(self.action_agent, LMM):
+            return format_step(
+                self.action_agent(
+                    self._build_agent_prompt(question, reflections, scratchpad),
+                    images=[image] if image is not None else None,
+                )
+            )
+        elif isinstance(self.action_agent, Agent):
+            return format_step(
+                self.action_agent(
+                    self._build_agent_prompt(question, reflections, scratchpad),
+                    image=image,
+                )
             )
-        )
 
     def prompt_reflection(
         self,
@@ -261,7 +269,7 @@ class Reflexion(Agent):
         return format_step(
             self.self_reflect_model(
                 self._build_reflect_prompt(question, context, scratchpad),
-                image=image,
+                images=[image] if image is not None else None,
             )
         )
 
vision_agent/agent/vision_agent.py CHANGED
@@ -3,7 +3,7 @@ import logging
 import sys
 import tempfile
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 from PIL import Image
 from tabulate import tabulate
@@ -37,10 +37,10 @@ _LOGGER = logging.getLogger(__name__)
 
 def parse_json(s: str) -> Any:
     s = (
-        s.replace(": true", ": True")
-        .replace(": false", ": False")
-        .replace(":true", ": True")
-        .replace(":false", ": False")
+        s.replace(": True", ": true")
+        .replace(": False", ": false")
+        .replace(":True", ": true")
+        .replace(":False", ": false")
         .replace("```", "")
         .strip()
     )
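The boolean replacement is reversed in 0.1.3: model output that uses Python-style True/False is now normalized to JSON-style true/false before parsing, instead of the other way around. A minimal sketch of the idea, assuming the cleaned string is ultimately handed to json.loads (the remainder of parse_json is outside this hunk):

import json
from typing import Any


def parse_json_sketch(s: str) -> Any:
    # Normalize Python-style booleans to valid JSON, drop code fences, then parse.
    s = (
        s.replace(": True", ": true")
        .replace(": False", ": false")
        .replace(":True", ": true")
        .replace(":False", ": false")
        .replace("```", "")
        .strip()
    )
    return json.loads(s)


# '{"Finish": True, ...}' is not valid JSON on its own, but parses after normalization.
print(parse_json_sketch('{"Finish": True, "Reflection": "ok"}'))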
@@ -62,6 +62,19 @@ def format_tools(tools: Dict[int, Any]) -> str:
     return tool_str
 
 
+def format_tool_usage(tools: Dict[int, Any], tool_result: List[Dict]) -> str:
+    usage = []
+    name_to_usage = {v["name"]: v["usage"] for v in tools.values()}
+    for tool_res in tool_result:
+        if "tool_name" in tool_res:
+            usage.append((tool_res["tool_name"], name_to_usage[tool_res["tool_name"]]))
+
+    usage_str = ""
+    for tool_name, tool_usage in usage:
+        usage_str += f"{tool_name} - {tool_usage}\n"
+    return usage_str
+
+
 def topological_sort(tasks: List[Dict]) -> List[Dict]:
     in_degree = {task["id"]: 0 for task in tasks}
     for task in tasks:
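The new format_tool_usage helper pairs each executed tool with its registered usage entry so the reflection prompt can quote the relevant API documentation. A hedged illustration; the registry values below are made up, and only the name/usage/tool_name fields mirror the code above:

tools = {
    0: {"name": "grounding_dino_", "usage": "required: prompt (str), image (str)"},
    1: {"name": "box_iou_", "usage": "required: bbox1 (List[int]), bbox2 (List[int])"},
}
tool_result = [
    {"tool_name": "grounding_dino_", "parameters": {"prompt": "car", "image": "image.jpg"}}
]

print(format_tool_usage(tools, tool_result))
# grounding_dino_ - required: prompt (str), image (str)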
@@ -251,58 +264,46 @@ def self_reflect(
     tools: Dict[int, Any],
     tool_result: List[Dict],
     final_answer: str,
-    image: Optional[Union[str, Path]] = None,
+    images: Optional[Sequence[Union[str, Path]]] = None,
 ) -> str:
     prompt = VISION_AGENT_REFLECTION.format(
         question=question,
-        tools=format_tools(tools),
+        tools=format_tools({k: v["description"] for k, v in tools.items()}),
+        tool_usage=format_tool_usage(tools, tool_result),
         tool_results=str(tool_result),
         final_answer=final_answer,
     )
     if (
         issubclass(type(reflect_model), LMM)
-        and image is not None
-        and Path(image).suffix in [".jpg", ".jpeg", ".png"]
+        and images is not None
+        and all([Path(image).suffix in [".jpg", ".jpeg", ".png"] for image in images])
     ):
-        return reflect_model(prompt, image=image)  # type: ignore
+        return reflect_model(prompt, images=images)  # type: ignore
     return reflect_model(prompt)
 
 
-def parse_reflect(reflect: str) -> bool:
-    # GPT-4V has a hard time following directions, so make the criteria less strict
-    return (
+def parse_reflect(reflect: str) -> Any:
+    reflect = reflect.strip()
+    try:
+        return parse_json(reflect)
+    except Exception:
+        _LOGGER.error(f"Failed parse json reflection: {reflect}")
+        # LMMs have a hard time following directions, so make the criteria less strict
+        finish = (
             "finish" in reflect.lower() and len(reflect) < 100
         ) or "finish" in reflect.lower()[-10:]
-
-
-def visualize_result(all_tool_results: List[Dict]) -> List[str]:
-    image_to_data: Dict[str, Dict] = {}
-    for tool_result in all_tool_results:
-        if tool_result["tool_name"] not in ["grounding_sam_", "grounding_dino_"]:
-            continue
-
-        parameters = tool_result["parameters"]
-        # parameters can either be a dictionary or list, parameters can also be malformed
-        # becaus the LLM builds them
-        if isinstance(parameters, dict):
-            if "image" not in parameters:
-                continue
-            parameters = [parameters]
-        elif isinstance(tool_result["parameters"], list):
-            if len(tool_result["parameters"]) < 1 or (
-                "image" not in tool_result["parameters"][0]
-            ):
-                continue
-
-        for param, call_result in zip(parameters, tool_result["call_results"]):
-            # calls can fail, so we need to check if the call was successful
-            if not isinstance(call_result, dict):
-                continue
-            if "bboxes" not in call_result:
-                continue
-
-            # if the call was successful, then we can add the image data
-            image = param["image"]
+        return {"Finish": finish, "Reflection": reflect}
+
+
+def _handle_extract_frames(
+    image_to_data: Dict[str, Dict], tool_result: Dict
+) -> Dict[str, Dict]:
+    image_to_data = image_to_data.copy()
+    # handle extract_frames_ case, useful if it extracts frames but doesn't do
+    # any following processing
+    for video_file_output in tool_result["call_results"]:
+        for frame, _ in video_file_output:
+            image = frame
             if image not in image_to_data:
                 image_to_data[image] = {
                     "bboxes": [],
@@ -310,17 +311,72 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
                     "labels": [],
                     "scores": [],
                 }
+    return image_to_data
+
+
+def _handle_viz_tools(
+    image_to_data: Dict[str, Dict], tool_result: Dict
+) -> Dict[str, Dict]:
+    image_to_data = image_to_data.copy()
+
+    # handle grounding_sam_ and grounding_dino_
+    parameters = tool_result["parameters"]
+    # parameters can either be a dictionary or list, parameters can also be malformed
+    # becaus the LLM builds them
+    if isinstance(parameters, dict):
+        if "image" not in parameters:
+            return image_to_data
+        parameters = [parameters]
+    elif isinstance(tool_result["parameters"], list):
+        if len(tool_result["parameters"]) < 1 or (
+            "image" not in tool_result["parameters"][0]
+        ):
+            return image_to_data
+
+    for param, call_result in zip(parameters, tool_result["call_results"]):
+        # calls can fail, so we need to check if the call was successful
+        if not isinstance(call_result, dict) or "bboxes" not in call_result:
+            return image_to_data
+
+        # if the call was successful, then we can add the image data
+        image = param["image"]
+        if image not in image_to_data:
+            image_to_data[image] = {
+                "bboxes": [],
+                "masks": [],
+                "labels": [],
+                "scores": [],
+            }
+
+        image_to_data[image]["bboxes"].extend(call_result["bboxes"])
+        image_to_data[image]["labels"].extend(call_result["labels"])
+        image_to_data[image]["scores"].extend(call_result["scores"])
+        if "masks" in call_result:
+            image_to_data[image]["masks"].extend(call_result["masks"])
+
+    return image_to_data
+
+
+def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]:
+    image_to_data: Dict[str, Dict] = {}
+    for tool_result in all_tool_results:
+        # only handle bbox/mask tools or frame extraction
+        if tool_result["tool_name"] not in [
+            "grounding_sam_",
+            "grounding_dino_",
+            "extract_frames_",
+        ]:
+            continue
 
-            image_to_data[image]["bboxes"].extend(call_result["bboxes"])
-            image_to_data[image]["labels"].extend(call_result["labels"])
-            image_to_data[image]["scores"].extend(call_result["scores"])
-            if "masks" in call_result:
-                image_to_data[image]["masks"].extend(call_result["masks"])
+        if tool_result["tool_name"] == "extract_frames_":
+            image_to_data = _handle_extract_frames(image_to_data, tool_result)
+        else:
+            image_to_data = _handle_viz_tools(image_to_data, tool_result)
 
     visualized_images = []
-    for image in image_to_data:
-        image_path = Path(image)
-        image_data = image_to_data[image]
+    for image_str in image_to_data:
+        image_path = Path(image_str)
+        image_data = image_to_data[image_str]
         image = overlay_masks(image_path, image_data)
         image = overlay_bboxes(image, image_data)
         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
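visualize_result now only dispatches per tool and delegates the bookkeeping to the two helpers above. A hedged sketch of the input shape it expects, based on the fields referenced in this code; the file name and box values are invented:

all_tool_results = [
    {
        "tool_name": "grounding_dino_",
        "parameters": {"prompt": "car", "image": "street.jpg"},
        "call_results": [
            {"bboxes": [[0.1, 0.2, 0.4, 0.5]], "labels": ["car"], "scores": [0.91]}
        ],
    }
]

# Returns paths of temporary .png files with the boxes/masks overlaid.
overlaid = visualize_result(all_tool_results)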
@@ -351,7 +407,7 @@ class VisionAgent(Agent):
         task_model: Optional[Union[LLM, LMM]] = None,
         answer_model: Optional[Union[LLM, LMM]] = None,
         reflect_model: Optional[Union[LLM, LMM]] = None,
-        max_retries: int = 2,
+        max_retries: int = 3,
         verbose: bool = False,
         report_progress_callback: Optional[Callable[[str], None]] = None,
     ):
@@ -374,7 +430,9 @@ class VisionAgent(Agent):
             OpenAILLM(temperature=0.1) if answer_model is None else answer_model
         )
         self.reflect_model = (
-            OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
+            OpenAILMM(json_mode=True, temperature=0.1)
+            if reflect_model is None
+            else reflect_model
         )
         self.max_retries = max_retries
         self.tools = TOOLS
@@ -461,20 +519,27 @@ class VisionAgent(Agent):
 
             visualized_output = visualize_result(all_tool_results)
             all_tool_results.append({"visualized_output": visualized_output})
+            if len(visualized_output) > 0:
+                reflection_images = visualized_output
+            elif image is not None:
+                reflection_images = [image]
+            else:
+                reflection_images = None
             reflection = self_reflect(
                 self.reflect_model,
                 question,
                 self.tools,
                 all_tool_results,
                 final_answer,
-                visualized_output[0] if len(visualized_output) > 0 else image,
+                reflection_images,
             )
             self.log_progress(f"Reflection: {reflection}")
-            if parse_reflect(reflection):
+            parsed_reflection = parse_reflect(reflection)
+            if parsed_reflection["Finish"]:
                 break
             else:
-                reflections += "\n" + reflection
-        # '<END>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
+                reflections += "\n" + parsed_reflection["Reflection"]
+        # '<ANSWER>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
         self.log_progress(
             f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</ANSWER>"
         )
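parse_reflect now always returns a dict, which is what the parsed_reflection["Finish"] and ["Reflection"] lookups above rely on. A small illustration of both parsing paths, assuming parse_json raises on non-JSON input as its except branch implies:

# Well-formed JSON from the reflection model:
parse_reflect('{"Finish": true, "Reflection": "The answer was correct."}')
# -> {'Finish': True, 'Reflection': 'The answer was correct.'}

# Free text falls back to the "finish" heuristic:
parse_reflect("The plan looks good. Finish")
# -> {'Finish': True, 'Reflection': 'The plan looks good. Finish'}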
vision_agent/agent/vision_agent_prompts.py CHANGED
@@ -1,4 +1,14 @@
-VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used. You must determine if the agent's answer was correct or incorrect. If the agent's answer was correct, respond with Finish. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. Do not make vague steps like re-evaluate the threshold, instead make concrete steps like use a threshold of 0.5 or whatever threshold you think would fix this issue. If the task cannot be completed with the existing tools, respond with Finish. Use complete sentences.
+VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question, the tool usage for each of the tools used and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used.
+
+Please note that:
+1. You must ONLY output parsible JSON format. If the agents output was correct set "Finish" to true, else set "Finish" to false. An example output looks like:
+{{"Finish": true, "Reflection": "The agent's answer was correct."}}
+2. You must utilize the image with the visualized bounding boxes or masks and determine if the tools were used correctly or if the tools were used incorrectly or the wrong tools were used.
+3. If the agent's answer was incorrect, you must diagnose the reason for failure and devise a new concise and concrete plan that aims to mitigate the same failure with the tools available. An example output looks like:
+{{"Finish": false, "Reflection": "I can see from the visualized bounding boxes that the agent's answer was incorrect because the grounding_dino_ tool produced false positive predictions. The agent should use the following tools with the following parameters:
+Step 1: Use 'grounding_dino_' with a 'prompt' of 'baby. bed' and a 'box_threshold' of 0.7 to reduce the false positives.
+Step 2: Use 'box_iou_' with the baby bounding box and the bed bounding box to determine if the baby is on the bed or not."}}
+4. If the task cannot be completed with the existing tools or by adjusting the parameters, set "Finish" to true.
 
 User's question: {question}
 
@@ -8,6 +18,9 @@ Tools available:
 Tasks and tools used:
 {tool_results}
 
+Tool's used API documentation:
+{tool_usage}
+
 Final answer:
 {final_answer}
 
@@ -127,4 +140,5 @@ These are subtasks and their answers:
 
 This is a reflection from a previous failed attempt:
 {reflections}
+
 Final answer: """
vision_agent/data/data.py CHANGED
@@ -63,9 +63,9 @@ class DataStore:
 
         self.df[name] = self.df["image_paths"].progress_apply(  # type: ignore
             lambda x: (
-                func(self.lmm.generate(prompt, image=x))
+                func(self.lmm.generate(prompt, images=[x]))
                 if func
-                else self.lmm.generate(prompt, image=x)
+                else self.lmm.generate(prompt, images=[x])
             )
         )
         return self
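The DataStore change is purely the call-site migration to the list-based images argument. A hedged sketch of the underlying pattern; the import path is taken from the file headers in this diff, while the column and file names are placeholders:

import pandas as pd

from vision_agent.lmm.lmm import OpenAILMM  # assumed import path

lmm = OpenAILMM()  # assumes an OpenAI API key is configured
df = pd.DataFrame({"image_paths": ["dog.jpg", "cat.png"]})
prompt = "Describe the image in one sentence."

# Each image path is now wrapped in a one-element list for the LMM call.
df["caption"] = df["image_paths"].apply(lambda x: lmm.generate(prompt, images=[x]))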
vision_agent/llm/llm.py CHANGED
@@ -33,7 +33,7 @@ class OpenAILLM(LLM):
 
     def __init__(
         self,
-        model_name: str = "gpt-4-turbo-preview",
+        model_name: str = "gpt-4-turbo",
         api_key: Optional[str] = None,
         json_mode: bool = False,
         **kwargs: Any
vision_agent/lmm/lmm.py CHANGED
@@ -30,12 +30,16 @@ def encode_image(image: Union[str, Path]) -> str:
 
 class LMM(ABC):
     @abstractmethod
-    def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
+    def generate(
+        self, prompt: str, images: Optional[List[Union[str, Path]]] = None
+    ) -> str:
         pass
 
     @abstractmethod
     def chat(
-        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+        self,
+        chat: List[Dict[str, str]],
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         pass
 
@@ -43,7 +47,7 @@ class LMM(ABC):
     def __call__(
         self,
         input: Union[str, List[Dict[str, str]]],
-        image: Optional[Union[str, Path]] = None,
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         pass
 
@@ -57,27 +61,29 @@ class LLaVALMM(LMM):
     def __call__(
         self,
         input: Union[str, List[Dict[str, str]]],
-        image: Optional[Union[str, Path]] = None,
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         if isinstance(input, str):
-            return self.generate(input, image)
-        return self.chat(input, image)
+            return self.generate(input, images)
+        return self.chat(input, images)
 
     def chat(
-        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+        self,
+        chat: List[Dict[str, str]],
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         raise NotImplementedError("Chat not supported for LLaVA")
 
     def generate(
         self,
         prompt: str,
-        image: Optional[Union[str, Path]] = None,
+        images: Optional[List[Union[str, Path]]] = None,
         temperature: float = 0.1,
         max_new_tokens: int = 1500,
     ) -> str:
         data = {"prompt": prompt}
-        if image:
-            data["image"] = encode_image(image)
+        if images and len(images) > 0:
+            data["image"] = encode_image(images[0])
         data["temperature"] = temperature  # type: ignore
         data["max_new_tokens"] = max_new_tokens  # type: ignore
         res = requests.post(
@@ -99,9 +105,10 @@ class OpenAILMM(LMM):
 
     def __init__(
         self,
-        model_name: str = "gpt-4-vision-preview",
+        model_name: str = "gpt-4-turbo",
         api_key: Optional[str] = None,
         max_tokens: int = 1024,
+        json_mode: bool = False,
         **kwargs: Any,
     ):
         if not api_key:
@@ -111,20 +118,25 @@ class OpenAILMM(LMM):
 
         self.client = OpenAI(api_key=api_key)
         self.model_name = model_name
-        self.max_tokens = max_tokens
+        if "max_tokens" not in kwargs:
+            kwargs["max_tokens"] = max_tokens
+        if json_mode:
+            kwargs["response_format"] = {"type": "json_object"}
         self.kwargs = kwargs
 
     def __call__(
         self,
         input: Union[str, List[Dict[str, str]]],
-        image: Optional[Union[str, Path]] = None,
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         if isinstance(input, str):
-            return self.generate(input, image)
-        return self.chat(input, image)
+            return self.generate(input, images)
+        return self.chat(input, images)
 
     def chat(
-        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+        self,
+        chat: List[Dict[str, str]],
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         fixed_chat = []
         for c in chat:
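With json_mode=True the stored kwargs now carry response_format={"type": "json_object"} into every chat completion request, which is how the VisionAgent constructor above gets a JSON-only reflection model by default. A short hedged sketch of the construction (the prompt is illustrative and an OpenAI API key is assumed to be configured):

from vision_agent.lmm.lmm import OpenAILMM  # import path assumed from the file header above

# Mirrors the VisionAgent default: reflections must come back as a JSON object.
reflect_model = OpenAILMM(json_mode=True, temperature=0.1)

raw = reflect_model('Return a JSON object with keys "Finish" and "Reflection".')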
@@ -132,33 +144,38 @@ class OpenAILMM(LMM):
             fixed_c["content"] = [{"type": "text", "text": c["content"]}]  # type: ignore
             fixed_chat.append(fixed_c)
 
-        if image:
-            extension = Path(image).suffix
-            if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
-                extension = "jpg"
-            elif extension.lower() == ".png":
-                extension = "png"
-            else:
-                raise ValueError(f"Unsupported image extension: {extension}")
-
-            encoded_image = encode_image(image)
-            fixed_chat[0]["content"].append(  # type: ignore
-                {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/{extension};base64,{encoded_image}",
-                        "detail": "low",
+        if images and len(images) > 0:
+            for image in images:
+                extension = Path(image).suffix
+                if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
+                    extension = "jpg"
+                elif extension.lower() == ".png":
+                    extension = "png"
+                else:
+                    raise ValueError(f"Unsupported image extension: {extension}")
+
+                encoded_image = encode_image(image)
+                fixed_chat[0]["content"].append(  # type: ignore
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/{extension};base64,{encoded_image}",
+                            "detail": "low",
+                        },
                     },
-                },
-            )
+                )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
+            model=self.model_name, messages=fixed_chat, **self.kwargs  # type: ignore
         )
 
         return cast(str, response.choices[0].message.content)
 
-    def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
+    def generate(
+        self,
+        prompt: str,
+        images: Optional[List[Union[str, Path]]] = None,
+    ) -> str:
         message: List[Dict[str, Any]] = [
             {
                 "role": "user",
@@ -167,21 +184,22 @@ class OpenAILMM(LMM):
                 ],
             }
         ]
-        if image:
-            extension = Path(image).suffix
-            encoded_image = encode_image(image)
-            message[0]["content"].append(
-                {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/{extension};base64,{encoded_image}",
-                        "detail": "low",
+        if images and len(images) > 0:
+            for image in images:
+                extension = Path(image).suffix
+                encoded_image = encode_image(image)
+                message[0]["content"].append(
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/{extension};base64,{encoded_image}",
+                            "detail": "low",
+                        },
                     },
-                },
-            )
+                )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=message, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
+            model=self.model_name, messages=message, **self.kwargs  # type: ignore
         )
         return cast(str, response.choices[0].message.content)
 
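Both generate and chat now accept a list of image paths and attach each one to the first message as its own image_url part. A hedged usage sketch; the file names are placeholders and an OpenAI API key is assumed to be configured:

from vision_agent.lmm.lmm import OpenAILMM

lmm = OpenAILMM()

# Each .jpg/.png in the list is base64-encoded and appended to the request.
answer = lmm.generate(
    "Which of these frames shows a person near the door?",
    images=["frame_001.png", "frame_002.png"],
)

# The same list-based argument flows through __call__ and chat.
answer = lmm(
    [{"role": "user", "content": "Compare the two frames."}],
    images=["frame_001.png", "frame_002.png"],
)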
vision_agent/tools/__init__.py CHANGED
@@ -1,10 +1,10 @@
 from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
-from .tools import (
+from .tools import (  # Counter,
     CLIP,
     TOOLS,
     BboxArea,
     BboxIoU,
-    Counter,
+    BoxDistance,
     Crop,
     ExtractFrames,
     GroundingDINO,
vision_agent/tools/tools.py CHANGED
@@ -1,7 +1,6 @@
 import logging
 import tempfile
 from abc import ABC
-from collections import Counter as CounterClass
 from pathlib import Path
 from typing import Any, Dict, List, Tuple, Union, cast
 
@@ -396,33 +395,6 @@ class AgentGroundingSAM(GroundingSAM):
         return rets
 
 
-class Counter(Tool):
-    r"""Counter detects and counts the number of objects in an image given an input such as a category name or referring expression."""
-
-    name = "counter_"
-    description = "'counter_' detects and counts the number of objects in an image given an input such as a category name or referring expression. It returns a dictionary containing the labels and their counts."
-    usage = {
-        "required_parameters": [
-            {"name": "prompt", "type": "str"},
-            {"name": "image", "type": "str"},
-        ],
-        "examples": [
-            {
-                "scenario": "Can you count the number of cars in this image? Image name image.jpg",
-                "parameters": {"prompt": "car", "image": "image.jpg"},
-            },
-            {
-                "scenario": "Can you count the number of people? Image name: people.png",
-                "parameters": {"prompt": "person", "image": "people.png"},
-            },
-        ],
-    }
-
-    def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict:
-        resp = GroundingDINO()(prompt, image)
-        return dict(CounterClass(resp["labels"]))
-
-
 class Crop(Tool):
     r"""Crop crops an image given a bounding box and returns a file name of the cropped image."""
 
@@ -573,11 +545,42 @@ class SegIoU(Tool):
         return cast(float, round(iou, 2))
 
 
+class BoxDistance(Tool):
+    name = "box_distance_"
+    description = (
+        "'box_distance_' returns the minimum distance between two bounding boxes."
+    )
+    usage = {
+        "required_parameters": [
+            {"name": "bbox1", "type": "List[int]"},
+            {"name": "bbox2", "type": "List[int]"},
+        ],
+        "examples": [
+            {
+                "scenario": "If you want to calculate the distance between the bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]",
+                "parameters": {
+                    "bbox1": [0.2, 0.21, 0.34, 0.42],
+                    "bbox2": [0.3, 0.31, 0.44, 0.52],
+                },
+            }
+        ],
+    }
+
+    def __call__(self, bbox1: List[int], bbox2: List[int]) -> float:
+        x11, y11, x12, y12 = bbox1
+        x21, y21, x22, y22 = bbox2
+
+        horizontal_dist = np.max([0, x21 - x12, x11 - x22])
+        vertical_dist = np.max([0, y21 - y12, y11 - y22])
+
+        return cast(float, round(np.sqrt(horizontal_dist**2 + vertical_dist**2), 2))
+
+
 class ExtractFrames(Tool):
     r"""Extract frames from a video."""
 
     name = "extract_frames_"
-    description = "'extract_frames_' extracts frames where there is motion detected in a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where teh frame was captured. The frame is a local image file path."
+    description = "'extract_frames_' extracts frames from a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path."
     usage = {
         "required_parameters": [{"name": "video_uri", "type": "str"}],
         "examples": [
@@ -650,12 +653,12 @@ TOOLS = {
             GroundingDINO,
             AgentGroundingSAM,
             ExtractFrames,
-            Counter,
             Crop,
             BboxArea,
             SegArea,
             BboxIoU,
             SegIoU,
+            BoxDistance,
            Calculator,
        ]
    )
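The new box_distance_ tool registered above measures the gap between the closest edges of two boxes, so overlapping boxes score 0. A quick standalone check of the same arithmetic using the example boxes from the usage block (this mirrors the tool's __call__, not a separate API):

import numpy as np


def box_distance(bbox1, bbox2):
    # Minimum edge-to-edge distance; 0 when the boxes touch or overlap.
    x11, y11, x12, y12 = bbox1
    x21, y21, x22, y22 = bbox2
    horizontal_dist = np.max([0, x21 - x12, x11 - x22])
    vertical_dist = np.max([0, y21 - y12, y11 - y22])
    return round(float(np.sqrt(horizontal_dist**2 + vertical_dist**2)), 2)


# The example boxes overlap, so the reported distance is 0.0.
print(box_distance([0.2, 0.21, 0.34, 0.42], [0.3, 0.31, 0.44, 0.52]))  # 0.0

# Disjoint boxes: horizontal gap 0.1, vertical gap 0.2 -> sqrt(0.01 + 0.04) ~= 0.22
print(box_distance([0.0, 0.0, 0.2, 0.2], [0.3, 0.4, 0.5, 0.6]))  # 0.22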
vision_agent/tools/video.py CHANGED
@@ -15,7 +15,7 @@ _CLIP_LENGTH = 30.0
 
 
 def extract_frames_from_video(
-    video_uri: str, fps: int = 2, motion_detection_threshold: float = 0.06
+    video_uri: str, fps: float = 0.5, motion_detection_threshold: float = 0.0
 ) -> List[Tuple[np.ndarray, float]]:
     """Extract frames from a video
 
@@ -25,7 +25,8 @@ def extract_frames_from_video(
         motion_detection_threshold: The threshold to detect motion between
             changes/frames. A value between 0-1, which represents the percentage change
             required for the frames to be considered in motion. For example, a lower
-            value means more frames will be extracted.
+            value means more frames will be extracted. A non-positive value will disable
+            motion detection and extract all frames.
 
     Returns:
         a list of tuples containing the extracted frame and the timestamp in seconds.
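With the new defaults (fps=0.5, motion_detection_threshold=0.0) the function samples one frame every two seconds and keeps all of them; passing a positive threshold restores the old motion filtering. A hedged usage sketch; the video path is a placeholder:

from vision_agent.tools.video import extract_frames_from_video

# 0.1.3 defaults: 0.5 fps, motion detection disabled.
frames = extract_frames_from_video("driveway.mp4")

# Opt back into motion-based frame skipping with a positive threshold.
moving_frames = extract_frames_from_video(
    "driveway.mp4", fps=2, motion_detection_threshold=0.06
)

for frame, timestamp in frames:
    print(timestamp, frame.shape)  # frame is a numpy array, timestamp in seconds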
@@ -119,18 +120,19 @@ def _extract_frames_by_clip(
         total=processable_frames, desc=f"Extracting frames from clip {start}-{end}"
     )
     for i, frame in enumerate(clip.iter_frames(fps=fps, dtype="uint8")):
-        curr_processed_frame = _preprocess_frame(frame)
         total_count += 1
         pbar.update(1)
-        # Skip the frame if it is similar to the previous one
-        if prev_processed_frame is not None and _similar_frame(
-            prev_processed_frame,
-            curr_processed_frame,
-            threshold=motion_detection_threshold,
-        ):
-            skipped_count += 1
-            continue
-        prev_processed_frame = curr_processed_frame
+        if motion_detection_threshold > 0:
+            curr_processed_frame = _preprocess_frame(frame)
+            # Skip the frame if it is similar to the previous one
+            if prev_processed_frame is not None and _similar_frame(
+                prev_processed_frame,
+                curr_processed_frame,
+                threshold=motion_detection_threshold,
+            ):
+                skipped_count += 1
+                continue
+            prev_processed_frame = curr_processed_frame
         ts = round(clip.reader.pos / source_fps, 3)
         frames.append((frame, ts))
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: vision-agent
-Version: 0.1.1
+Version: 0.1.3
 Summary: Toolset for Vision Agent
 Author: Landing AI
 Author-email: dev@landing.ai
@@ -2,28 +2,28 @@ vision_agent/__init__.py,sha256=wD1cssVTAJ55uTViNfBGooqJUV0p9fmVAuTMHHrmUBU,229
 vision_agent/agent/__init__.py,sha256=B4JVrbY4IRVCJfjmrgvcp7h1mTUEk8MZvL0Zmej4Ka0,127
 vision_agent/agent/agent.py,sha256=X7kON-g9ePUKumCDaYfQNBX_MEFE-ax5PnRp7-Cc5Wo,529
 vision_agent/agent/easytool.py,sha256=oMHnBg7YBtIPgqQUNcZgq7uMgpPThs99_UnO7ERkMVg,11511
-vision_agent/agent/easytool_prompts.py,sha256=dYzWa_RaiaFSQ-CowoQOcFmjZtBTTljRyA809bLgrvU,4519
-vision_agent/agent/reflexion.py,sha256=wzpptfALNZIh9Q5jgkK3imGL5LWjTW_n_Ypsvxdh07Q,10101
+vision_agent/agent/easytool_prompts.py,sha256=zdQQw6WpXOmvwOMtlBlNKY5a3WNlr65dbUvMIGiqdeo,4526
+vision_agent/agent/reflexion.py,sha256=4gz30BuFMeGxSsTzoDV4p91yE0R8LISXp28IaOI6wdM,10506
 vision_agent/agent/reflexion_prompts.py,sha256=G7UAeNz_g2qCb2yN6OaIC7bQVUkda4m3z42EG8wAyfE,9342
-vision_agent/agent/vision_agent.py,sha256=nHmfr-OuMfdH0N8gECXLzTAgRmTx9cYe5_pnQj-HnBE,19764
-vision_agent/agent/vision_agent_prompts.py,sha256=dPg0mLVK_fGJpYK2xXGhm-zuXX1KVZW_zFXyYsspUz8,6567
+vision_agent/agent/vision_agent.py,sha256=4-milD0iSY_vKdpAIctba04Ak_In5tMBE8gATdaGIr0,22019
+vision_agent/agent/vision_agent_prompts.py,sha256=W3Z72FpUt71UIJSkjAcgtQqxeMqkYuATqHAN5fYY26c,7342
 vision_agent/data/__init__.py,sha256=YU-5g3LbEQ6a4drz0RLGTagXMVU2Z4Xr3RlfWE-R0jU,46
-vision_agent/data/data.py,sha256=pgtSGZdAnbQ8oGsuapLtFTMPajnCGDGekEXTnFuBwsY,5122
+vision_agent/data/data.py,sha256=Z2l76OrT0GgyuN52OeJqDitUcP0q1rhfdXd1of3GsVo,5128
 vision_agent/emb/__init__.py,sha256=YmCkGrJBtXb6X6Z3lnKiFoQYKXMgHMJp8JJyMLVvqcI,75
 vision_agent/emb/emb.py,sha256=la9lhEzk7jqUCjYYQ5oRgVNSnC9_EJBJIpE_B9c6PJo,1375
 vision_agent/fonts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 vision_agent/fonts/default_font_ch_en.ttf,sha256=1YM0Z3XqLDjSNbF7ihQFSAIUdjF9m1rtHiNC_6QosTE,1594400
 vision_agent/image_utils.py,sha256=hFdPoRmeVU5jErFr5xaagMQ6Wy7Xbw8H8HXuLGdJIAM,4786
 vision_agent/llm/__init__.py,sha256=BoUm_zSAKnLlE8s-gKTSQugXDqVZKPqYlWwlTLdhcz4,48
-vision_agent/llm/llm.py,sha256=tgL6ZtuwZKuxSNiCxJCuP2ETjNMrosdgxXkZJb0_00E,5024
+vision_agent/llm/llm.py,sha256=Jty_RHdqVmIM0Mm31JNk50c882Tx7hHtkmh0WyXeJd8,5016
 vision_agent/lmm/__init__.py,sha256=nnNeKD1k7q_4vLb1x51O_EUTYaBgGfeiCx5F433gr3M,67
-vision_agent/lmm/lmm.py,sha256=LxwxCArp7DfnPbjf_Gl55xBxPwo2Qx8eDp1gCnGYSO0,9535
-vision_agent/tools/__init__.py,sha256=OEqEysxm5wnnOD73NKNCUggALB72GEmVg9FNsEkSBtA,253
+vision_agent/lmm/lmm.py,sha256=1E7e_S_0fOKnf6mSsEdkXvsIjGmhBGl5XW4By2jvhbY,10045
+vision_agent/tools/__init__.py,sha256=lKv90gLu-mNp4uyGtJ8AUG-73xKwFEugZpe0atpsscA,269
 vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
-vision_agent/tools/tools.py,sha256=Qsqe8X6VjB0EMWhyKJ5EMPyLIc_d5Vtlw4ugV2FB_Ks,25589
-vision_agent/tools/video.py,sha256=40rscP8YvKN3lhZ4PDcOK4XbdFX2duCRpHY_krmBYKU,7476
+vision_agent/tools/tools.py,sha256=EK9HauKZ1gq795wBZNER6-8PiDTNZwJ1sXYhDeplDZ0,25410
+vision_agent/tools/video.py,sha256=xTElFSFp1Jw4ulOMnk81Vxsh-9dTxcWUO6P9fzEi3AM,7653
 vision_agent/type_defs.py,sha256=4LTnTL4HNsfYqCrDn9Ppjg9bSG2ZGcoKSSd9YeQf4Bw,1792
-vision_agent-0.1.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-vision_agent-0.1.1.dist-info/METADATA,sha256=rWMocnnZwuRhd3xIGyQUzDbsndVASBSu2jvAqt-3Odc,6233
-vision_agent-0.1.1.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
-vision_agent-0.1.1.dist-info/RECORD,,
+vision_agent-0.1.3.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+vision_agent-0.1.3.dist-info/METADATA,sha256=iBoN2GBvALl6XxhxRo4o9WaqLgI-UAobSymuZ1RHd9o,6233
+vision_agent-0.1.3.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
+vision_agent-0.1.3.dist-info/RECORD,,