vision-agent 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vision_agent/agent/vision_agent.py CHANGED
@@ -37,10 +37,10 @@ _LOGGER = logging.getLogger(__name__)
 
 def parse_json(s: str) -> Any:
     s = (
-        s.replace(": true", ": True")
-        .replace(": false", ": False")
-        .replace(":true", ": True")
-        .replace(":false", ": False")
+        s.replace(": True", ": true")
+        .replace(": False", ": false")
+        .replace(":True", ": true")
+        .replace(":False", ": false")
         .replace("```", "")
         .strip()
     )
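The boolean replacements now run in the opposite direction: instead of rewriting JSON literals into Python ones, `parse_json` normalizes Python-style `True`/`False` (which models often emit) into valid JSON before parsing. A minimal sketch of the new behavior, assuming the unchanged tail of the function is a standard `json.loads` call (it is outside this hunk):

```python
import json
from typing import Any


def parse_json(s: str) -> Any:
    # Normalize Python-style booleans into JSON ones, strip markdown code
    # fences, then parse as ordinary JSON.
    s = (
        s.replace(": True", ": true")
        .replace(": False", ": false")
        .replace(":True", ": true")
        .replace(":False", ": false")
        .replace("```", "")
        .strip()
    )
    return json.loads(s)


# A reflection emitted with Python-style booleans now parses cleanly:
print(parse_json('{"Finish": True, "Reflection": "The answer is correct."}'))
# {'Finish': True, 'Reflection': 'The answer is correct.'}
```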
@@ -62,6 +62,19 @@ def format_tools(tools: Dict[int, Any]) -> str:
     return tool_str
 
 
+def format_tool_usage(tools: Dict[int, Any], tool_result: List[Dict]) -> str:
+    usage = []
+    name_to_usage = {v["name"]: v["usage"] for v in tools.values()}
+    for tool_res in tool_result:
+        if "tool_name" in tool_res:
+            usage.append((tool_res["tool_name"], name_to_usage[tool_res["tool_name"]]))
+
+    usage_str = ""
+    for tool_name, tool_usage in usage:
+        usage_str += f"{tool_name} - {tool_usage}\n"
+    return usage_str
+
+
 def topological_sort(tasks: List[Dict]) -> List[Dict]:
     in_degree = {task["id"]: 0 for task in tasks}
     for task in tasks:
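The new `format_tool_usage` helper collects the usage string of every tool that actually ran, so the reflection prompt can quote exact signatures back to the model. A quick illustration with a hypothetical registry and execution trace (real values come from the agent's `TOOLS` dict and its recorded results; the import path is assumed):

```python
from vision_agent.agent.vision_agent import format_tool_usage  # assumed import path

# Hypothetical inputs for illustration only.
tools = {
    0: {"name": "grounding_dino_", "usage": "grounding_dino_(prompt, image, box_threshold)"},
    1: {"name": "box_iou_", "usage": "box_iou_(bbox1, bbox2)"},
}
tool_result = [{"tool_name": "grounding_dino_", "parameters": {}, "call_results": []}]

print(format_tool_usage(tools, tool_result))
# grounding_dino_ - grounding_dino_(prompt, image, box_threshold)
```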
@@ -255,7 +268,8 @@ def self_reflect(
 ) -> str:
     prompt = VISION_AGENT_REFLECTION.format(
         question=question,
-        tools=format_tools(tools),
+        tools=format_tools({k: v["description"] for k, v in tools.items()}),
+        tool_usage=format_tool_usage(tools, tool_result),
         tool_results=str(tool_result),
         final_answer=final_answer,
     )
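`self_reflect` now splits the registry: `format_tools` receives only the per-tool descriptions, while the full entries feed `format_tool_usage`. Together with the hunk above, this implies each `TOOLS` value is a dict carrying at least `"name"`, `"description"`, and `"usage"`. A sketch of that implied shape (field values hypothetical):

```python
# Implied shape of a TOOLS entry after this change (values are made up):
tools = {
    0: {
        "name": "grounding_dino_",
        "description": "Detects objects in an image from a text prompt.",
        "usage": "grounding_dino_(prompt, image, box_threshold)",
    },
}

descriptions = {k: v["description"] for k, v in tools.items()}  # goes to format_tools
# format_tool_usage(tools, tool_result) renders the matching usage docs.
```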
@@ -268,41 +282,28 @@ def self_reflect(
     return reflect_model(prompt)
 
 
-def parse_reflect(reflect: str) -> bool:
-    # GPT-4V has a hard time following directions, so make the criteria less strict
-    return (
+def parse_reflect(reflect: str) -> Any:
+    reflect = reflect.strip()
+    try:
+        return parse_json(reflect)
+    except Exception:
+        _LOGGER.error(f"Failed to parse JSON reflection: {reflect}")
+        # LMMs have a hard time following directions, so make the criteria less strict
+        finish = (
             "finish" in reflect.lower() and len(reflect) < 100
         ) or "finish" in reflect.lower()[-10:]
-
-
-def visualize_result(all_tool_results: List[Dict]) -> List[str]:
-    image_to_data: Dict[str, Dict] = {}
-    for tool_result in all_tool_results:
-        if tool_result["tool_name"] not in ["grounding_sam_", "grounding_dino_"]:
-            continue
-
-        parameters = tool_result["parameters"]
-        # parameters can either be a dictionary or list, parameters can also be malformed
-        # because the LLM builds them
-        if isinstance(parameters, dict):
-            if "image" not in parameters:
-                continue
-            parameters = [parameters]
-        elif isinstance(tool_result["parameters"], list):
-            if len(tool_result["parameters"]) < 1 or (
-                "image" not in tool_result["parameters"][0]
-            ):
-                continue
-
-        for param, call_result in zip(parameters, tool_result["call_results"]):
-            # calls can fail, so we need to check if the call was successful
-            if not isinstance(call_result, dict):
-                continue
-            if "bboxes" not in call_result:
-                continue
-
-            # if the call was successful, then we can add the image data
-            image = param["image"]
+        return {"Finish": finish, "Reflection": reflect}
+
+
+def _handle_extract_frames(
+    image_to_data: Dict[str, Dict], tool_result: Dict
+) -> Dict[str, Dict]:
+    image_to_data = image_to_data.copy()
+    # handle extract_frames_ case, useful if it extracts frames but doesn't do
+    # any following processing
+    for video_file_output in tool_result["call_results"]:
+        for frame, _ in video_file_output:
+            image = frame
             if image not in image_to_data:
                 image_to_data[image] = {
                     "bboxes": [],
@@ -310,17 +311,72 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
                     "labels": [],
                     "scores": [],
                 }
+    return image_to_data
+
+
+def _handle_viz_tools(
+    image_to_data: Dict[str, Dict], tool_result: Dict
+) -> Dict[str, Dict]:
+    image_to_data = image_to_data.copy()
+
+    # handle grounding_sam_ and grounding_dino_
+    parameters = tool_result["parameters"]
+    # parameters can either be a dictionary or list, parameters can also be malformed
+    # because the LLM builds them
+    if isinstance(parameters, dict):
+        if "image" not in parameters:
+            return image_to_data
+        parameters = [parameters]
+    elif isinstance(tool_result["parameters"], list):
+        if len(tool_result["parameters"]) < 1 or (
+            "image" not in tool_result["parameters"][0]
+        ):
+            return image_to_data
+
+    for param, call_result in zip(parameters, tool_result["call_results"]):
+        # calls can fail, so we need to check if the call was successful
+        if not isinstance(call_result, dict) or "bboxes" not in call_result:
+            return image_to_data
+
+        # if the call was successful, then we can add the image data
+        image = param["image"]
+        if image not in image_to_data:
+            image_to_data[image] = {
+                "bboxes": [],
+                "masks": [],
+                "labels": [],
+                "scores": [],
+            }
+
+        image_to_data[image]["bboxes"].extend(call_result["bboxes"])
+        image_to_data[image]["labels"].extend(call_result["labels"])
+        image_to_data[image]["scores"].extend(call_result["scores"])
+        if "masks" in call_result:
+            image_to_data[image]["masks"].extend(call_result["masks"])
+
+    return image_to_data
+
 
-            image_to_data[image]["bboxes"].extend(call_result["bboxes"])
-            image_to_data[image]["labels"].extend(call_result["labels"])
-            image_to_data[image]["scores"].extend(call_result["scores"])
-            if "masks" in call_result:
-                image_to_data[image]["masks"].extend(call_result["masks"])
+def visualize_result(all_tool_results: List[Dict]) -> List[str]:
+    image_to_data: Dict[str, Dict] = {}
+    for tool_result in all_tool_results:
+        # only handle bbox/mask tools or frame extraction
+        if tool_result["tool_name"] not in [
+            "grounding_sam_",
+            "grounding_dino_",
+            "extract_frames_",
+        ]:
+            continue
+
+        if tool_result["tool_name"] == "extract_frames_":
+            image_to_data = _handle_extract_frames(image_to_data, tool_result)
+        else:
+            image_to_data = _handle_viz_tools(image_to_data, tool_result)
 
     visualized_images = []
-    for image in image_to_data:
-        image_path = Path(image)
-        image_data = image_to_data[image]
+    for image_str in image_to_data:
+        image_path = Path(image_str)
+        image_data = image_to_data[image_str]
         image = overlay_masks(image_path, image_data)
         image = overlay_bboxes(image, image_data)
         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
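The monolithic `visualize_result` loop is split into two per-tool handlers keyed off `tool_name`, both folding results into a shared `image_to_data` map; frame extraction now produces blank overlay entries even when no detector ran on the frames. Since `_handle_viz_tools` is pure dict manipulation, its effect can be checked directly (paths and boxes below are made up; the import assumes the helper is module-level in `vision_agent.agent.vision_agent`, as the diff suggests):

```python
from vision_agent.agent.vision_agent import _handle_viz_tools  # assumed import path

# Hypothetical grounding_dino_ result; the image path and box are made up.
tool_result = {
    "tool_name": "grounding_dino_",
    "parameters": {"image": "frame_0.png", "prompt": "baby. bed"},
    "call_results": [
        {"bboxes": [[0.1, 0.2, 0.4, 0.5]], "labels": ["baby"], "scores": [0.9]}
    ],
}

print(_handle_viz_tools({}, tool_result))
# {'frame_0.png': {'bboxes': [[0.1, 0.2, 0.4, 0.5]], 'masks': [],
#                  'labels': ['baby'], 'scores': [0.9]}}
```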
@@ -374,7 +430,9 @@ class VisionAgent(Agent):
             OpenAILLM(temperature=0.1) if answer_model is None else answer_model
         )
         self.reflect_model = (
-            OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
+            OpenAILMM(json_mode=True, temperature=0.1)
+            if reflect_model is None
+            else reflect_model
         )
         self.max_retries = max_retries
         self.tools = TOOLS
@@ -470,11 +528,12 @@ class VisionAgent(Agent):
                 visualized_output[0] if len(visualized_output) > 0 else image,
             )
             self.log_progress(f"Reflection: {reflection}")
-            if parse_reflect(reflection):
+            parsed_reflection = parse_reflect(reflection)
+            if parsed_reflection["Finish"]:
                 break
             else:
-                reflections += "\n" + reflection
-        # '<END>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
+                reflections += "\n" + parsed_reflection["Reflection"]
+        # '<ANSWER>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
         self.log_progress(
             f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</ANSWER>"
         )
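With `json_mode` enabled on the reflect model, the reflection should already be JSON; `parse_reflect` falls back to the old keyword heuristic when it is not, so the agent loop receives the same dict shape either way. Two illustrative calls (import path assumed):

```python
from vision_agent.agent.vision_agent import parse_reflect  # assumed import path

# Expected case: the LMM obeyed the JSON instruction.
parse_reflect('{"Finish": false, "Reflection": "Lower the box_threshold to 0.5."}')
# -> {'Finish': False, 'Reflection': 'Lower the box_threshold to 0.5.'}

# Fallback case: free-form text; the keyword heuristic decides "Finish".
parse_reflect("Finish")
# -> {'Finish': True, 'Reflection': 'Finish'}
```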
vision_agent/agent/vision_agent_prompts.py CHANGED
@@ -1,4 +1,14 @@
-VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-reflection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question, and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used. You must determine if the agent's answer was correct or incorrect. If the agent's answer was correct, respond with Finish. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. Do not make vague steps like re-evaluate the threshold, instead make concrete steps like use a threshold of 0.5 or whatever threshold you think would fix this issue. If the task cannot be completed with the existing tools, respond with Finish. Use complete sentences.
+VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-reflection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question, and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used.
+
+Please note that:
+1. You must ONLY output parsable JSON format. If the agent's output was correct set "Finish" to true, else set "Finish" to false. An example output looks like:
+{{"Finish": true, "Reflection": "The agent's answer was correct."}}
+2. You must utilize the image with the visualized bounding boxes or masks and determine if the tools were used correctly or, using your own judgement, utilized incorrectly.
+3. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. An example output looks like:
+{{"Finish": false, "Reflection": "I can see from the visualized bounding boxes that the agent's answer was incorrect because the grounding_dino_ tool produced false positive predictions. The agent should use the following tools with the following parameters:
+Step 1: Use 'grounding_dino_' with a 'prompt' of 'baby. bed' and a 'box_threshold' of 0.7 to reduce the false positives.
+Step 2: Use 'box_iou_' with the baby bounding box and the bed bounding box to determine if the baby is on the bed or not."}}
+4. If the task cannot be completed with the existing tools or by adjusting the parameters, set "Finish" to true.
 
 User's question: {question}
 
@@ -8,6 +18,9 @@ Tools available:
 Tasks and tools used:
 {tool_results}
 
+API documentation for the tools used:
+{tool_usage}
+
 Final answer:
 {final_answer}
 
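Note the doubled braces around the JSON examples in the prompt: they survive `str.format` as literal braces while the single-brace fields are substituted. A hypothetical fill of the updated template (in the agent these values come from `self_reflect`'s arguments; import path assumed):

```python
from vision_agent.agent.vision_agent_prompts import VISION_AGENT_REFLECTION  # assumed

prompt = VISION_AGENT_REFLECTION.format(
    question="Is the baby on the bed?",
    tools="0: Detects objects in an image from a text prompt.",
    tool_results="[{'task': 'detect the baby', 'tool_name': 'grounding_dino_'}]",
    tool_usage="grounding_dino_ - grounding_dino_(prompt, image, box_threshold)",
    final_answer="Yes, the baby is on the bed.",
)
# The {{"Finish": ...}} examples render as literal {"Finish": ...} JSON.
print(prompt)
```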
vision_agent/llm/llm.py CHANGED
@@ -33,7 +33,7 @@ class OpenAILLM(LLM):
 
     def __init__(
         self,
-        model_name: str = "gpt-4-turbo-preview",
+        model_name: str = "gpt-4-turbo",
         api_key: Optional[str] = None,
         json_mode: bool = False,
         **kwargs: Any
vision_agent/lmm/lmm.py CHANGED
@@ -99,9 +99,10 @@ class OpenAILMM(LMM):
 
     def __init__(
         self,
-        model_name: str = "gpt-4-vision-preview",
+        model_name: str = "gpt-4-turbo",
         api_key: Optional[str] = None,
         max_tokens: int = 1024,
+        json_mode: bool = False,
         **kwargs: Any,
     ):
         if not api_key:
@@ -111,7 +112,10 @@ class OpenAILMM(LMM):
 
         self.client = OpenAI(api_key=api_key)
         self.model_name = model_name
-        self.max_tokens = max_tokens
+        if "max_tokens" not in kwargs:
+            kwargs["max_tokens"] = max_tokens
+        if json_mode:
+            kwargs["response_format"] = {"type": "json_object"}
         self.kwargs = kwargs
 
     def __call__(
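Folding `max_tokens` and `response_format` into `self.kwargs` means every option reaches `chat.completions.create` through a single `**kwargs` expansion, and callers can override `max_tokens` per instance. What `OpenAILMM(json_mode=True)` effectively sends, written out by hand against the OpenAI client (message text is hypothetical; OpenAI's JSON mode requires the word "json" to appear in the prompt):

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

response = client.chat.completions.create(
    model="gpt-4-turbo",
    messages=[{"role": "user", "content": "Answer in JSON: is 2 + 2 equal to 4?"}],
    max_tokens=1024,                          # default, now routed via kwargs
    response_format={"type": "json_object"},  # what json_mode=True injects
)
print(response.choices[0].message.content)
```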
@@ -153,7 +157,7 @@ class OpenAILMM(LMM):
         )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
+            model=self.model_name, messages=fixed_chat, **self.kwargs  # type: ignore
         )
 
         return cast(str, response.choices[0].message.content)
@@ -181,7 +185,7 @@ class OpenAILMM(LMM):
         )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=message, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
+            model=self.model_name, messages=message, **self.kwargs  # type: ignore
         )
         return cast(str, response.choices[0].message.content)
 
vision_agent-0.1.1.dist-info/METADATA → vision_agent-0.1.2.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: vision-agent
-Version: 0.1.1
+Version: 0.1.2
 Summary: Toolset for Vision Agent
 Author: Landing AI
 Author-email: dev@landing.ai
vision_agent-0.1.1.dist-info/RECORD → vision_agent-0.1.2.dist-info/RECORD RENAMED
@@ -5,8 +5,8 @@ vision_agent/agent/easytool.py,sha256=oMHnBg7YBtIPgqQUNcZgq7uMgpPThs99_UnO7ERkMV
 vision_agent/agent/easytool_prompts.py,sha256=dYzWa_RaiaFSQ-CowoQOcFmjZtBTTljRyA809bLgrvU,4519
 vision_agent/agent/reflexion.py,sha256=wzpptfALNZIh9Q5jgkK3imGL5LWjTW_n_Ypsvxdh07Q,10101
 vision_agent/agent/reflexion_prompts.py,sha256=G7UAeNz_g2qCb2yN6OaIC7bQVUkda4m3z42EG8wAyfE,9342
-vision_agent/agent/vision_agent.py,sha256=nHmfr-OuMfdH0N8gECXLzTAgRmTx9cYe5_pnQj-HnBE,19764
-vision_agent/agent/vision_agent_prompts.py,sha256=dPg0mLVK_fGJpYK2xXGhm-zuXX1KVZW_zFXyYsspUz8,6567
+vision_agent/agent/vision_agent.py,sha256=_xh3v7DaeH3r5JLeXtCvDbQgogGRvpmqH3dAW7ChA1E,21759
+vision_agent/agent/vision_agent_prompts.py,sha256=JC43AB0ZnL8jQW9LYe-5mTeEJmH0w-SuH9YmGQxf1eM,7311
 vision_agent/data/__init__.py,sha256=YU-5g3LbEQ6a4drz0RLGTagXMVU2Z4Xr3RlfWE-R0jU,46
 vision_agent/data/data.py,sha256=pgtSGZdAnbQ8oGsuapLtFTMPajnCGDGekEXTnFuBwsY,5122
 vision_agent/emb/__init__.py,sha256=YmCkGrJBtXb6X6Z3lnKiFoQYKXMgHMJp8JJyMLVvqcI,75
@@ -15,15 +15,15 @@ vision_agent/fonts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuF
 vision_agent/fonts/default_font_ch_en.ttf,sha256=1YM0Z3XqLDjSNbF7ihQFSAIUdjF9m1rtHiNC_6QosTE,1594400
 vision_agent/image_utils.py,sha256=hFdPoRmeVU5jErFr5xaagMQ6Wy7Xbw8H8HXuLGdJIAM,4786
 vision_agent/llm/__init__.py,sha256=BoUm_zSAKnLlE8s-gKTSQugXDqVZKPqYlWwlTLdhcz4,48
-vision_agent/llm/llm.py,sha256=tgL6ZtuwZKuxSNiCxJCuP2ETjNMrosdgxXkZJb0_00E,5024
+vision_agent/llm/llm.py,sha256=Jty_RHdqVmIM0Mm31JNk50c882Tx7hHtkmh0WyXeJd8,5016
 vision_agent/lmm/__init__.py,sha256=nnNeKD1k7q_4vLb1x51O_EUTYaBgGfeiCx5F433gr3M,67
-vision_agent/lmm/lmm.py,sha256=LxwxCArp7DfnPbjf_Gl55xBxPwo2Qx8eDp1gCnGYSO0,9535
+vision_agent/lmm/lmm.py,sha256=qDdy_9Q9wRjJ9ZUfqB8zfjhVIgITgjF7K4hYaTAoPCI,9637
 vision_agent/tools/__init__.py,sha256=OEqEysxm5wnnOD73NKNCUggALB72GEmVg9FNsEkSBtA,253
 vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
 vision_agent/tools/tools.py,sha256=Qsqe8X6VjB0EMWhyKJ5EMPyLIc_d5Vtlw4ugV2FB_Ks,25589
 vision_agent/tools/video.py,sha256=40rscP8YvKN3lhZ4PDcOK4XbdFX2duCRpHY_krmBYKU,7476
 vision_agent/type_defs.py,sha256=4LTnTL4HNsfYqCrDn9Ppjg9bSG2ZGcoKSSd9YeQf4Bw,1792
-vision_agent-0.1.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-vision_agent-0.1.1.dist-info/METADATA,sha256=rWMocnnZwuRhd3xIGyQUzDbsndVASBSu2jvAqt-3Odc,6233
-vision_agent-0.1.1.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
-vision_agent-0.1.1.dist-info/RECORD,,
+vision_agent-0.1.2.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+vision_agent-0.1.2.dist-info/METADATA,sha256=6AP0Z9G4l15uCcfBGhUfHV1AnP4lwXQuey7uH-QuvlU,6233
+vision_agent-0.1.2.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
+vision_agent-0.1.2.dist-info/RECORD,,