vision-agent 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,5 @@
  from .agent import Agent
+ from .agent_coder import AgentCoder
  from .easytool import EasyTool
  from .reflexion import Reflexion
  from .vision_agent import VisionAgent
@@ -0,0 +1,170 @@
+ import json
+ import logging
+ import os
+ import sys
+ from pathlib import Path
+ from typing import Dict, List, Optional, Union
+
+ from vision_agent.agent import Agent
+ from vision_agent.llm import LLM, OpenAILLM
+ from vision_agent.lmm import LMM, OpenAILMM
+ from vision_agent.tools.tools_v2 import TOOLS_DOCSTRING, UTILITIES_DOCSTRING
+
+ from .agent_coder_prompts import DEBUG, FIX_BUG, PROGRAM, TEST, VISUAL_TEST
+ from .execution import IMPORT_HELPER, check_correctness
+
+ logging.basicConfig(stream=sys.stdout)
+ _LOGGER = logging.getLogger(__name__)
+
+
+ def write_tests(question: str, code: str, model: LLM) -> str:
+     prompt = TEST.format(
+         question=question,
+         code=code,
+     )
+     completion = model(prompt)
+     return preprocess_data(completion)
+
+
+ def preprocess_data(code: str) -> str:
+     if "```python" in code:
+         code = code[code.find("```python") + len("```python") :]
+         code = code[: code.find("```")]
+     return code
+
+
+ def parse_file_name(s: str) -> str:
+     # We only output png files
+     return "".join([p for p in s.split(" ") if p.endswith(".png")])
+
+
+ def write_program(question: str, feedback: str, model: LLM) -> str:
+     prompt = PROGRAM.format(
+         docstring=TOOLS_DOCSTRING, question=question, feedback=feedback
+     )
+     completion = model(prompt)
+     return preprocess_data(completion)
+
+
+ def write_debug(question: str, code: str, feedback: str, model: LLM) -> str:
+     prompt = DEBUG.format(
+         docstring=UTILITIES_DOCSTRING,
+         code=code,
+         question=question,
+         feedback=feedback,
+     )
+     completion = model(prompt)
+     return preprocess_data(completion)
+
+
+ def execute_tests(code: str, tests: str) -> Dict[str, Union[str, bool]]:
+     full_code = f"{IMPORT_HELPER}\n{code}\n{tests}"
+     return check_correctness(full_code, 20.0)
+
+
+ def run_visual_tests(
+     question: str, code: str, viz_file: str, feedback: str, model: LMM
+ ) -> Dict[str, Union[str, bool]]:
+     prompt = VISUAL_TEST.format(
+         docstring=TOOLS_DOCSTRING,
+         code=code,
+         question=question,
+         feedback=feedback,
+     )
+     completion = model(prompt, images=[viz_file])
+     # type is from the prompt
+     return json.loads(completion)  # type: ignore
+
+
+ def fix_bugs(code: str, tests: str, result: str, feedback: str, model: LLM) -> str:
+     prompt = FIX_BUG.format(code=code, tests=tests, result=result, feedback=feedback)
+     completion = model(prompt)
+     return preprocess_data(completion)
+
+
+ class AgentCoder(Agent):
+     """AgentCoder is based on the AgentCoder paper https://arxiv.org/abs/2312.13010
+     and its open-source code https://github.com/huangd1999/AgentCoder, with some key
+     differences. AgentCoder comprises three components: a coder agent, a tester agent,
+     and an executor. The tester agent writes code to test the code written by the coder
+     agent, but in our case, because we are solving a vision task, it is difficult to write
+     testing code. We instead have the tester agent write code to visualize the output
+     of the code written by the coder agent. If the code fails, we pass it back to the
+     coder agent to fix the bug; if it succeeds, we pass it to a visual tester agent, which
+     is an LMM model like GPT-4V, to visually inspect the output and make sure it looks
+     good."""
+
+     def __init__(
+         self,
+         coder_agent: Optional[LLM] = None,
+         tester_agent: Optional[LLM] = None,
+         visual_tester_agent: Optional[LMM] = None,
+         verbose: bool = False,
+     ) -> None:
+         self.coder_agent = (
+             OpenAILLM(temperature=0.1) if coder_agent is None else coder_agent
+         )
+         self.tester_agent = (
+             OpenAILLM(temperature=0.1) if tester_agent is None else tester_agent
+         )
+         self.visual_tester_agent = (
+             OpenAILMM(temperature=0.1, json_mode=True)
+             if visual_tester_agent is None
+             else visual_tester_agent
+         )
+         self.max_turns = 3
+         if verbose:
+             _LOGGER.setLevel(logging.INFO)
+
+     def __call__(
+         self,
+         input: Union[List[Dict[str, str]], str],
+         image: Optional[Union[str, Path]] = None,
+     ) -> str:
+         if isinstance(input, str):
+             input = [{"role": "user", "content": input}]
+         return self.chat(input, image)
+
+     def chat(
+         self,
+         input: List[Dict[str, str]],
+         image: Optional[Union[str, Path]] = None,
+     ) -> str:
+         question = input[0]["content"]
+         if image:
+             question += f" Input file path: {os.path.abspath(image)}"
+
+         code = ""
+         feedback = ""
+         for _ in range(self.max_turns):
+             code = write_program(question, feedback, self.coder_agent)
+             _LOGGER.info(f"code:\n{code}")
+             debug = write_debug(question, code, feedback, self.tester_agent)
+             _LOGGER.info(f"debug:\n{debug}")
+             results = execute_tests(code, debug)
+             _LOGGER.info(
+                 f"execution results: passed: {results['passed']}\n{results['result']}"
+             )
+
+             if not results["passed"]:
+                 code = fix_bugs(
+                     code, debug, results["result"].strip(), feedback, self.coder_agent  # type: ignore
+                 )
+                 _LOGGER.info(f"fixed code:\n{code}")
+             else:
+                 # TODO: Sometimes it prints nothing, so we need to handle that case
+                 # TODO: The visual agent reflection does not work very well, needs more testing
+                 # viz_test_results = run_visual_tests(
+                 #     question, code, parse_file_name(results["result"].strip()), feedback, self.visual_tester_agent
+                 # )
+                 # _LOGGER.info(f"visual test results:\n{viz_test_results}")
+                 # if viz_test_results["finished"]:
+                 #     return f"{IMPORT_HELPER}\n{code}"
+                 # feedback += f"\n{viz_test_results['feedback']}"
+
+                 return f"{IMPORT_HELPER}\n{code}"
+
+         return f"{IMPORT_HELPER}\n{code}"
+
+     def log_progress(self, description: str) -> None:
+         _LOGGER.info(description)
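
For reference, a minimal usage sketch of the new `AgentCoder` agent added in this release, assuming an OpenAI API key is configured so the default `OpenAILLM`/`OpenAILMM` models can be constructed; the prompt and image path are placeholders:

```python
# Minimal usage sketch (illustrative): build the agent with defaults and ask a
# vision question. The returned string is runnable Python consisting of
# IMPORT_HELPER plus the generated (and, if needed, bug-fixed) code.
from vision_agent.agent import AgentCoder

agent = AgentCoder(verbose=True)
code = agent("Count the number of cars in the image", image="road.jpg")
print(code)
```

Each of the up-to-three turns writes a program with the coder model, writes debug/visualization code with the tester model, and runs both through `check_correctness` from `execution.py`.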
@@ -0,0 +1,135 @@
+ PROGRAM = """
+ **Role**: You are a software programmer.
+
+ **Task**: As a programmer, you are required to complete the function. Use a Chain-of-Thought approach to break down the problem, create pseudocode, and then write the code in Python. Ensure that your code is efficient, readable, and well-commented. Return the requested information from the function you create.
+
+ **Documentation**:
+ This is the documentation for the functions you have access to. You may call any of these functions to help you complete the task; you do not need to worry about defining or importing them and can assume they are available to you.
+ {docstring}
+
+ **Input Code Snippet**:
+ ```python
+ def execute(image_path: str):
+     # Your code here
+ ```
+
+ **User Instructions**:
+ {question}
+
+ **Previous Feedback**:
+ {feedback}
+
+ **Instructions**:
+ 1. **Understand and Clarify**: Make sure you understand the task.
+ 2. **Algorithm/Method Selection**: Decide on the most efficient way.
+ 3. **Pseudocode Creation**: Write down the steps you will follow in pseudocode.
+ 4. **Code Generation**: Translate your pseudocode into executable Python code.
+ """
+
+ DEBUG = """
+ **Role**: You are a software programmer.
+
+ **Task**: Your task is to run the `execute` function and either print the output or print a file name containing visualized output for another agent to examine. The other agent will then use your output, either the printed return value of the function or the visualized output as a file, to determine if `execute` is functioning correctly.
+
+ **Documentation**
+ This is the documentation for the functions you have access to. You may call any of these functions to help you complete the task; you do not need to worry about defining or importing them and can assume they are available to you.
+ {docstring}
+
+ **Input Code Snippet**:
+ ```python
+ ### Please decide how you want to generate test cases, based on either the incomplete code or the completed version.
+ {code}
+ ```
+
+ **User Instructions**:
+ {question}
+
+ **Previous Feedback**:
+ {feedback}
+
+ **Instructions**:
+ 1. **Understand and Clarify**: Make sure you understand the task.
+ 2. **Code Execution**: Run the `execute` function with the given input from the user instructions.
+ 3. **Output Generation**: Print the output or save it as a file for visualization utilizing the functions you have access to.
+ """
+
+ VISUAL_TEST = """
+ **Role**: You are a machine vision expert.
+
+ **Task**: Your task is to visually inspect the output of the `execute` function and determine if the visualization of the function output looks correct given the user's instructions. If not, you can provide suggestions to improve the `execute` function.
+
+ **Documentation**:
+ This is the documentation for the functions you have access to. You may call any of these functions to help you complete the task; you do not need to worry about defining or importing them and can assume they are available to you.
+ {docstring}
+
+
+ **Input Code Snippet**:
+ This is the code that produced the visualized output:
+ ```python
+ {code}
+ ```
+
+ **User Instructions**:
+ {question}
+
+ **Previous Feedback**:
+ {feedback}
+
+ **Instructions**:
+ 1. **Visual Inspection**: Examine the visual output of the `execute` function.
+ 2. **Evaluation**: Determine if the visualization is correct based on the user's instructions.
+ 3. **Feedback**: Provide feedback on the visualization and suggest improvements if necessary.
+ 4. **Clear Concrete Instructions**: Provide clear, concrete instructions to improve the results. You can only make coding suggestions based on either the input code snippet or the documented code provided. For example, do not say the threshold needs to be adjusted; instead, provide an exact value for adjusting the threshold.
+
+ Provide output in JSON format {{"finished": boolean, "feedback": "your feedback"}} where "finished" is True if the output is correct and False if not, and "feedback" is your feedback.
+ """
+
+ FIX_BUG = """
+ Please re-complete the code to fix the error message. Here is the previous version:
+ ```python
+ {code}
+ ```
+
+ When we run this code:
+ ```python
+ {tests}
+ ```
+
+ It raises this error:
+ ```python
+ {result}
+ ```
+
+ This is previous feedback provided on the code:
+ {feedback}
+
+ Please fix the bug by following the error information and only return Python code. You do not need to return the test cases. The re-completed code should be in triple backtick format (i.e., in ```python ```).
+ """
+
+ TEST = """
+ **Role**: As a tester, your task is to create comprehensive test cases for the incomplete `execute` function. These test cases should encompass Basic, Edge, and Large Scale scenarios to ensure the code's robustness, reliability, and scalability.
+
+ **User Instructions**:
+ {question}
+
+ **Input Code Snippet**:
+ ```python
+ ### Please decide how you want to generate test cases, based on either the incomplete code or the completed version.
+ {code}
+ ```
+
+ **1. Basic Test Cases**:
+ - **Objective**: To verify the fundamental functionality of the `execute` function under normal conditions.
+
+ **2. Edge Test Cases**:
+ - **Objective**: To evaluate the function's behavior under extreme or unusual conditions.
+
+ **3. Large Scale Test Cases**:
+ - **Objective**: To assess the function's performance and scalability with large data samples.
+
+ **Instructions**:
+ - Implement a comprehensive set of test cases following the guidelines above.
+ - Ensure each test case is well-documented with comments explaining the scenario it covers.
+ - Pay special attention to edge cases as they often reveal hidden bugs.
+ - For large-scale tests, focus on the function's efficiency and performance under heavy loads.
+ """
@@ -0,0 +1,287 @@
+ """This code is based on code from CodeGeeX https://github.com/THUDM/CodeGeeX"""
+
+ import contextlib
+ import faulthandler
+ import io
+ import multiprocessing
+ import os
+ import platform
+ import signal
+ import tempfile
+ import traceback
+ import typing
+ from pathlib import Path
+ from typing import Dict, Generator, List, Optional, Union
+
+ IMPORT_HELPER = """
+ import math
+ import re
+ import sys
+ import copy
+ import datetime
+ import itertools
+ import collections
+ import heapq
+ import statistics
+ import functools
+ import hashlib
+ import numpy
+ import numpy as np
+ import string
+ from typing import *
+ from collections import *
+ from vision_agent.tools.tools_v2 import *
+ """
+
+
+ def unsafe_execute(code: str, timeout: float, result: List) -> None:
+     with create_tempdir() as dir:
+         code_path = Path(dir) / "code.py"
+         with open(code_path, "w") as f:
+             f.write(code)
+
+         # These system calls are needed when cleaning up tempdir.
+         import os
+         import shutil
+
+         rmtree = shutil.rmtree
+         rmdir = os.rmdir
+         chdir = os.chdir
+
+         # Disable functionalities that can make destructive changes to the test.
+         reliability_guard()
+
+         try:
+             with swallow_io() as s:
+                 with time_limit(timeout):
+                     # WARNING
+                     # This program exists to execute untrusted model-generated code. Although
+                     # it is highly unlikely that model-generated code will do something overtly
+                     # malicious in response to this test suite, model-generated code may act
+                     # destructively due to a lack of model capability or alignment.
+                     # Users are strongly encouraged to sandbox this evaluation suite so that it
+                     # does not perform destructive actions on their host or network.
+                     # Once you have read this disclaimer and taken appropriate precautions,
+                     # uncomment the following line and proceed at your own risk:
+                     code = compile(code, code_path, "exec")  # type: ignore
+                     exec(code)
+             result.append({"output": s.getvalue(), "passed": True})
+         except TimeoutError:
+             result.append({"output": "Timed out", "passed": False})
+         except AssertionError:
+             result.append({"output": f"{traceback.format_exc()}", "passed": False})
+         except BaseException:
+             result.append({"output": f"{traceback.format_exc()}", "passed": False})
+
+         # Needed for cleaning up.
+         shutil.rmtree = rmtree
+         os.rmdir = rmdir
+         os.chdir = chdir
+
+         code_path.unlink()
+
+
+ def check_correctness(
+     code: str,
+     timeout: float = 3.0,
+ ) -> Dict[str, Union[str, bool]]:
+     """Evaluates the functional correctness of a completion by running the test suite
+     provided in the problem.
+     """
+
+     manager = multiprocessing.Manager()
+     result = manager.list()
+
+     # p = multiprocessing.Process(target=unsafe_execute, args=(tmp_dir,))
+     p = multiprocessing.Process(target=unsafe_execute, args=(code, timeout, result))
+     p.start()
+     p.join(timeout=timeout + 1)
+     if p.is_alive():
+         p.kill()
+
+     if not result:
+         result.append({"output": "Timed out", "passed": False})
+
+     return {
+         "code": code,
+         "result": result[0]["output"],
+         "passed": result[0]["passed"],
+     }
+
+
+ # Copyright (c) OpenAI (https://openai.com)
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in
+ # all copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ # THE SOFTWARE.
+ # ============================================================================
+
+
+ class redirect_stdin(contextlib._RedirectStream):
+     _stream = "stdin"
+
+
+ @contextlib.contextmanager
+ def chdir(root: str) -> Generator[None, None, None]:
+     if root == ".":
+         yield
+         return
+     cwd = os.getcwd()
+     os.chdir(root)
+     try:
+         yield
+     except BaseException as exc:
+         raise exc
+     finally:
+         os.chdir(cwd)
+
+
+ class WriteOnlyStringIO(io.StringIO):
+     """StringIO that throws an exception when it's read from"""
+
+     def read(self, *args, **kwargs):  # type: ignore
+         raise IOError
+
+     def readline(self, *args, **kwargs):  # type: ignore
+         raise IOError
+
+     def readlines(self, *args, **kwargs):  # type: ignore
+         raise IOError
+
+     def readable(self, *args, **kwargs):  # type: ignore
+         """Returns True if the IO object can be read."""
+         return False
+
+
+ @contextlib.contextmanager
+ def create_tempdir() -> Generator[str, None, None]:
+     with tempfile.TemporaryDirectory() as dirname:
+         with chdir(dirname):
+             yield dirname
+
+
+ @contextlib.contextmanager
+ def swallow_io() -> Generator[WriteOnlyStringIO, None, None]:
+     stream = WriteOnlyStringIO()
+     with contextlib.redirect_stdout(stream):
+         with contextlib.redirect_stderr(stream):
+             with redirect_stdin(stream):
+                 yield stream
+
+
+ @typing.no_type_check
+ @contextlib.contextmanager
+ def time_limit(seconds: float) -> Generator[None, None, None]:
+     def signal_handler(signum, frame):
+         raise TimeoutError("Timed out!")
+
+     if platform.uname().system != "Windows":
+         signal.setitimer(signal.ITIMER_REAL, seconds)
+         signal.signal(signal.SIGALRM, signal_handler)
+     try:
+         yield
+     finally:
+         signal.setitimer(signal.ITIMER_REAL, 0)
+
+
+ @typing.no_type_check
+ def reliability_guard(maximum_memory_bytes: Optional[int] = None) -> None:
+     """
+     This disables various destructive functions and prevents the generated code
+     from interfering with the test (e.g. fork bomb, killing other processes,
+     removing filesystem files, etc.)
+
+     WARNING
+     This function is NOT a security sandbox. Untrusted code, including model-
+     generated code, should not be blindly executed outside of one. See the
+     Codex paper for more information about OpenAI's code sandbox, and proceed
+     with caution.
+     """
+
+     if maximum_memory_bytes is not None:
+         if platform.uname().system != "Windows":
+             import resource
+
+             resource.setrlimit(
+                 resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)
+             )
+             resource.setrlimit(
+                 resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)
+             )
+             if not platform.uname().system == "Darwin":
+                 resource.setrlimit(
+                     resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)
+                 )
+
+     faulthandler.disable()
+
+     import builtins
+
+     builtins.exit = None
+     builtins.quit = None
+
+     import os
+
+     os.environ["OMP_NUM_THREADS"] = "1"
+
+     os.kill = None
+     os.system = None
+     # os.putenv = None  # this causes numpy to fail on import
+     os.remove = None
+     os.removedirs = None
+     os.rmdir = None
+     os.fchdir = None
+     os.setuid = None
+     os.fork = None
+     os.forkpty = None
+     os.killpg = None
+     os.rename = None
+     os.renames = None
+     os.truncate = None
+     os.replace = None
+     os.unlink = None
+     os.fchmod = None
+     os.fchown = None
+     os.chmod = None
+     os.chown = None
+     os.chroot = None
+     os.fchdir = None
+     os.lchflags = None
+     os.lchmod = None
+     os.lchown = None
+     os.getcwd = None
+     os.chdir = None
+
+     import shutil
+
+     shutil.rmtree = None
+     shutil.move = None
+     shutil.chown = None
+
+     import subprocess
+
+     subprocess.Popen = None  # type: ignore
+
+     __builtins__["help"] = None
+
+     import sys
+
+     sys.modules["ipdb"] = None
+     sys.modules["joblib"] = None
+     sys.modules["resource"] = None
+     sys.modules["psutil"] = None
+     sys.modules["tkinter"] = None
@@ -489,6 +489,7 @@ class VisionAgent(Agent):
          image: Optional[Union[str, Path]] = None,
          reference_data: Optional[Dict[str, str]] = None,
          visualize_output: Optional[bool] = False,
+         self_reflection: Optional[bool] = True,
      ) -> str:
          """Invoke the vision agent.
  
@@ -501,6 +502,7 @@ class VisionAgent(Agent):
                  {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
                  where the bounding box coordinates are normalized.
              visualize_output: Whether to visualize the output.
+             self_reflection: Whether to enable self reflection.
  
          Returns:
              The result of the vision agent in text.
@@ -512,6 +514,7 @@ class VisionAgent(Agent):
              image=image,
              visualize_output=visualize_output,
              reference_data=reference_data,
+             self_reflection=self_reflection,
          )
  
      def log_progress(self, description: str) -> None:
@@ -538,6 +541,7 @@ class VisionAgent(Agent):
          image: Optional[Union[str, Path]] = None,
          reference_data: Optional[Dict[str, str]] = None,
          visualize_output: Optional[bool] = False,
+         self_reflection: Optional[bool] = True,
      ) -> Tuple[str, List[Dict]]:
          """Chat with the vision agent and return the final answer and all tool results.
  
@@ -550,6 +554,7 @@ class VisionAgent(Agent):
                  {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
                  where the bounding box coordinates are normalized.
              visualize_output: Whether to visualize the output.
+             self_reflection: Whether to enable self reflection.
  
          Returns:
              A tuple where the first item is the final answer and the second item is a
@@ -625,20 +630,25 @@ class VisionAgent(Agent):
                  reflection_images = [image]
              else:
                  reflection_images = None
-             reflection = self_reflect(
-                 self.reflect_model,
-                 question,
-                 self.tools,
-                 all_tool_results,
-                 final_answer,
-                 reflection_images,
-             )
-             self.log_progress(f"Reflection: {reflection}")
-             parsed_reflection = parse_reflect(reflection)
-             if parsed_reflection["Finish"]:
-                 break
+
+             if self_reflection:
+                 reflection = self_reflect(
+                     self.reflect_model,
+                     question,
+                     self.tools,
+                     all_tool_results,
+                     final_answer,
+                     reflection_images,
+                 )
+                 self.log_progress(f"Reflection: {reflection}")
+                 parsed_reflection = parse_reflect(reflection)
+                 if parsed_reflection["Finish"]:
+                     break
+                 else:
+                     reflections += "\n" + parsed_reflection["Reflection"]
              else:
-                 reflections += "\n" + parsed_reflection["Reflection"]
+                 self.log_progress("Self Reflection skipped based on user request.")
+                 break
          # '<ANSWER>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
          self.log_progress(
              f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</ANSWER>"
@@ -660,12 +670,14 @@ class VisionAgent(Agent):
          image: Optional[Union[str, Path]] = None,
          reference_data: Optional[Dict[str, str]] = None,
          visualize_output: Optional[bool] = False,
+         self_reflection: Optional[bool] = True,
      ) -> str:
          answer, _ = self.chat_with_workflow(
              chat,
              image=image,
              visualize_output=visualize_output,
              reference_data=reference_data,
+             self_reflection=self_reflection,
          )
          return answer
  
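
The net effect of these hunks is a new `self_reflection` flag threaded from `__call__` and `chat` into `chat_with_workflow`; a brief sketch of the new option, with a placeholder image path:

```python
# Illustrative: self_reflection defaults to True (the previous behavior); passing
# False skips the reflection loop after the first answer, trading accuracy for
# speed and cost.
from vision_agent.agent import VisionAgent

agent = VisionAgent()
answer = agent(
    "How many cars are in this image?",
    image="road.jpg",
    self_reflection=False,
)
print(answer)
```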
@@ -0,0 +1,27 @@
+ import logging
+ from typing import Any, Dict
+
+ import requests
+
+ from vision_agent.type_defs import LandingaiAPIKey
+
+ _LOGGER = logging.getLogger(__name__)
+ _LND_API_KEY = LandingaiAPIKey().api_key
+ _LND_API_URL = "https://api.dev.landing.ai/v1/agent"
+
+
+ def _send_inference_request(
+     payload: Dict[str, Any], endpoint_name: str
+ ) -> Dict[str, Any]:
+     res = requests.post(
+         f"{_LND_API_URL}/model/{endpoint_name}",
+         headers={
+             "Content-Type": "application/json",
+             "apikey": _LND_API_KEY,
+         },
+         json=payload,
+     )
+     if res.status_code != 200:
+         _LOGGER.error(f"Request failed: {res.text}")
+         raise ValueError(f"Request failed: {res.text}")
+     return res.json()["data"]  # type: ignore
@@ -20,12 +20,10 @@ from vision_agent.image_utils import (
      rle_decode,
  )
  from vision_agent.lmm import OpenAILMM
+ from vision_agent.tools.tool_utils import _send_inference_request
  from vision_agent.tools.video import extract_frames_from_video
- from vision_agent.type_defs import LandingaiAPIKey
  
  _LOGGER = logging.getLogger(__name__)
- _LND_API_KEY = LandingaiAPIKey().api_key
- _LND_API_URL = "https://api.dev.landing.ai/v1/agent"
  
  
  class Tool(ABC):
@@ -1221,20 +1219,3 @@ def register_tool(tool: Type[Tool]) -> Type[Tool]:
          "class": tool,
      }
      return tool
-
-
- def _send_inference_request(
-     payload: Dict[str, Any], endpoint_name: str
- ) -> Dict[str, Any]:
-     res = requests.post(
-         f"{_LND_API_URL}/model/{endpoint_name}",
-         headers={
-             "Content-Type": "application/json",
-             "apikey": _LND_API_KEY,
-         },
-         json=payload,
-     )
-     if res.status_code != 200:
-         _LOGGER.error(f"Request failed: {res.text}")
-         raise ValueError(f"Request failed: {res.text}")
-     return res.json()["data"]  # type: ignore
@@ -0,0 +1,181 @@
+ import inspect
+ import tempfile
+ from importlib import resources
+ from typing import Any, Callable, Dict, List
+
+ import numpy as np
+ from PIL import Image, ImageDraw, ImageFont
+
+ from vision_agent.image_utils import convert_to_b64, normalize_bbox
+ from vision_agent.tools.tool_utils import _send_inference_request
+
+ COLORS = [
+     (158, 218, 229),
+     (219, 219, 141),
+     (23, 190, 207),
+     (188, 189, 34),
+     (199, 199, 199),
+     (247, 182, 210),
+     (127, 127, 127),
+     (227, 119, 194),
+     (196, 156, 148),
+     (197, 176, 213),
+     (140, 86, 75),
+     (148, 103, 189),
+     (255, 152, 150),
+     (152, 223, 138),
+     (214, 39, 40),
+     (44, 160, 44),
+     (255, 187, 120),
+     (174, 199, 232),
+     (255, 127, 14),
+     (31, 119, 180),
+ ]
+
+
+ def grounding_dino(
+     prompt: str,
+     image: np.ndarray,
+     box_threshold: float = 0.20,
+     iou_threshold: float = 0.75,
+ ) -> List[Dict[str, Any]]:
+     """'grounding_dino' is a tool that can detect arbitrary objects with inputs such as
+     category names or referring expressions.
+
+     Parameters:
+         prompt (str): The prompt to ground to the image.
+         image (np.ndarray): The image to ground the prompt to.
+         box_threshold (float, optional): The threshold for the box detection. Defaults to 0.20.
+         iou_threshold (float, optional): The threshold for the Intersection over Union (IoU). Defaults to 0.75.
+
+     Returns:
+         List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
+             bounding box of the detected objects with normalized coordinates.
+
+     Example
+     -------
+         >>> grounding_dino("car. dinosaur", image)
+         [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}]
+     """
+     image_size = image.shape[:2]
+     image_b64 = convert_to_b64(Image.fromarray(image))
+     request_data = {
+         "prompt": prompt,
+         "image": image_b64,
+         "tool": "visual_grounding",
+         "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
+     }
+     data: Dict[str, Any] = _send_inference_request(request_data, "tools")
+     return_data = []
+     for i in range(len(data["bboxes"])):
+         return_data.append(
+             {
+                 "score": round(data["scores"][i], 2),
+                 "label": data["labels"][i],
+                 "bbox": normalize_bbox(data["bboxes"][i], image_size),
+             }
+         )
+     return return_data
+
+
+ def load_image(image_path: str) -> np.ndarray:
+     """'load_image' is a utility function that loads an image from the given path.
+
+     Parameters:
+         image_path (str): The path to the image.
+
+     Returns:
+         np.ndarray: The image as a NumPy array.
+
+     Example
+     -------
+         >>> load_image("path/to/image.jpg")
+     """
+
+     image = Image.open(image_path).convert("RGB")
+     return np.array(image)
+
+
+ def save_image(image: np.ndarray) -> str:
+     """'save_image' is a utility function that saves an image as a temporary file.
+
+     Parameters:
+         image (np.ndarray): The image to save.
+
+     Returns:
+         str: The path to the saved image.
+
+     Example
+     -------
+         >>> save_image(image)
+         "/tmp/tmpabc123.png"
+     """
+
+     with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
+         pil_image = Image.fromarray(image.astype(np.uint8))
+         pil_image.save(f, "PNG")
+         return f.name
+
+
+ def display_bounding_boxes(
+     image: np.ndarray, bboxes: List[Dict[str, Any]]
+ ) -> np.ndarray:
+     """'display_bounding_boxes' is a utility function that displays bounding boxes on an image.
+
+     Parameters:
+         image (np.ndarray): The image to display the bounding boxes on.
+         bboxes (List[Dict[str, Any]]): A list of dictionaries containing the bounding boxes.
+
+     Returns:
+         np.ndarray: The image with the bounding boxes displayed.
+
+     Example
+     -------
+         >>> image_with_bboxes = display_bounding_boxes(image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}])
+     """
+     pil_image = Image.fromarray(image.astype(np.uint8))
+
+     color = {
+         label: COLORS[i % len(COLORS)]
+         for i, label in enumerate(set([box["label"] for box in bboxes]))
+     }
+
+     width, height = pil_image.size
+     fontsize = max(12, int(min(width, height) / 40))
+     draw = ImageDraw.Draw(pil_image)
+     font = ImageFont.truetype(
+         str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
+         fontsize,
+     )
+
+     for elt in bboxes:
+         label = elt["label"]
+         box = elt["bbox"]
+         scores = elt["score"]
+
+         box = [
+             int(box[0] * width),
+             int(box[1] * height),
+             int(box[2] * width),
+             int(box[3] * height),
+         ]
+         draw.rectangle(box, outline=color[label], width=4)
+         text = f"{label}: {scores:.2f}"
+         text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
+         draw.rectangle((box[0], box[1], text_box[2], text_box[3]), fill=color[label])
+         draw.text((box[0], box[1]), text, fill="black", font=font)
+     return np.array(pil_image.convert("RGB"))
+
+
+ def get_tool_documentation(funcs: List[Callable]) -> str:
+     docstrings = ""
+     for func in funcs:
+         docstrings += f"{func.__name__}: {inspect.signature(func)}\n{func.__doc__}\n\n"
+
+     return docstrings
+
+
+ TOOLS_DOCSTRING = get_tool_documentation([load_image, grounding_dino])
+ UTILITIES_DOCSTRING = get_tool_documentation(
+     [load_image, save_image, display_bounding_boxes]
+ )
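
A short sketch chaining the new `tools_v2` functions, assuming network access to the hosted inference endpoint; the image path is a placeholder:

```python
# Illustrative: detect objects with grounding_dino, draw the boxes, and save the
# visualization to a temporary PNG (this is the flow AgentCoder's debug code is
# prompted to produce).
from vision_agent.tools.tools_v2 import (
    display_bounding_boxes,
    grounding_dino,
    load_image,
    save_image,
)

image = load_image("road.jpg")
detections = grounding_dino("car", image, box_threshold=0.20)
print(save_image(display_bounding_boxes(image, detections)))
```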
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: vision-agent
- Version: 0.2.11
+ Version: 0.2.13
  Summary: Toolset for Vision Agent
  Author: Landing AI
  Author-email: dev@landing.ai
@@ -150,7 +150,7 @@ you. For example:
  
  #### Custom Tools
  You can also add your own custom tools for your vision agent to use:
- 
+ 
  ```python
  from vision_agent.tools import Tool, register_tool
  @register_tool
@@ -188,6 +188,7 @@ find an example that creates a custom tool for template matching [here](examples
  | BboxIoU | BboxIoU returns the intersection over union of two bounding boxes normalized to 2 decimal places. |
  | SegIoU | SegIoU returns the intersection over union of two segmentation masks normalized to 2 decimal places. |
  | BoxDistance | BoxDistance returns the minimum distance between two bounding boxes normalized to 2 decimal places. |
+ | MaskDistance | MaskDistance returns the minimum distance between two segmentation masks in pixel units. |
  | BboxContains | BboxContains returns the intersection of two boxes over the target box area. It is good for checking if one box is contained within another box. |
  | ExtractFrames | ExtractFrames extracts frames with motion from a video. |
  | ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image. |
@@ -1,11 +1,14 @@
  vision_agent/__init__.py,sha256=GVLHCeK_R-zgldpbcPmOzJat-BkadvkuRCMxDvTIcXs,108
- vision_agent/agent/__init__.py,sha256=B4JVrbY4IRVCJfjmrgvcp7h1mTUEk8MZvL0Zmej4Ka0,127
+ vision_agent/agent/__init__.py,sha256=6AIN_lkqAz83dRN2WKpW7m1lQH1m-8BA2L-oUCFSKd4,163
  vision_agent/agent/agent.py,sha256=X7kON-g9ePUKumCDaYfQNBX_MEFE-ax5PnRp7-Cc5Wo,529
+ vision_agent/agent/agent_coder.py,sha256=65ZEF3_K2aAG0vve-Q9xOPP9uw30_vVD2Li6NE6BFRY,6073
+ vision_agent/agent/agent_coder_prompts.py,sha256=CJe3v7xvHQ32u3RQAXQga_Tk_4UgU64RBAMHZ3S70KY,5538
  vision_agent/agent/easytool.py,sha256=oMHnBg7YBtIPgqQUNcZgq7uMgpPThs99_UnO7ERkMVg,11511
  vision_agent/agent/easytool_prompts.py,sha256=Bikw-PPLkm78dwywTlnv32Y1Tw6JMeC-R7oCnXWLcTk,4656
+ vision_agent/agent/execution.py,sha256=wX8LwXDq_0g_bTPikNiaW6nz5bUC7fUlNQsQHe_7Ww0,8582
  vision_agent/agent/reflexion.py,sha256=4gz30BuFMeGxSsTzoDV4p91yE0R8LISXp28IaOI6wdM,10506
  vision_agent/agent/reflexion_prompts.py,sha256=G7UAeNz_g2qCb2yN6OaIC7bQVUkda4m3z42EG8wAyfE,9342
- vision_agent/agent/vision_agent.py,sha256=DVcvT02GjY85mCjhHgJGrhI_dpUvjZhoYzYik9bkHQA,26243
+ vision_agent/agent/vision_agent.py,sha256=5W5Xr_h4yDMsFvIk2JWcfMlYoPYmTv3JZnrDDumuZgM,26842
  vision_agent/agent/vision_agent_prompts.py,sha256=moihXFhEzFw8xnf2sUSgd_k9eoxQam3T6XUkB0fyp5o,8570
  vision_agent/fonts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  vision_agent/fonts/default_font_ch_en.ttf,sha256=1YM0Z3XqLDjSNbF7ihQFSAIUdjF9m1rtHiNC_6QosTE,1594400
@@ -16,10 +19,12 @@ vision_agent/lmm/__init__.py,sha256=nnNeKD1k7q_4vLb1x51O_EUTYaBgGfeiCx5F433gr3M,
  vision_agent/lmm/lmm.py,sha256=gK90vMxh0OcGSuIZQikBkDXm4pfkdFk1R2y7rtWDl84,10539
  vision_agent/tools/__init__.py,sha256=uWySwcIeQMH57PVN6lVIknTx-SFmN_J0mvn_HbGlXcQ,451
  vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
- vision_agent/tools/tools.py,sha256=kqwmKPbuSAGOWjzv2LCjsvUAp2mfRk8X5a1DrP2B4i8,47007
+ vision_agent/tools/tool_utils.py,sha256=kY-hBDIrapI-030nZuasXU83P6X3GR0Y_gOR32bnedw,747
+ vision_agent/tools/tools.py,sha256=8JzNtn_uKTyc-bztjnaGCY7ctRnfW5dRS-ppxaP-1RE,46427
+ vision_agent/tools/tools_v2.py,sha256=RxeaBTTkhqvATQGuYKiopeU4L2m0GbpPo-ypDmQ9UfY,5407
  vision_agent/tools/video.py,sha256=xTElFSFp1Jw4ulOMnk81Vxsh-9dTxcWUO6P9fzEi3AM,7653
  vision_agent/type_defs.py,sha256=4LTnTL4HNsfYqCrDn9Ppjg9bSG2ZGcoKSSd9YeQf4Bw,1792
- vision_agent-0.2.11.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- vision_agent-0.2.11.dist-info/METADATA,sha256=kg0CzT1ncFoXAg4ayP2ppStbFwHnzKAygH_t6XmKTxQ,8970
- vision_agent-0.2.11.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
- vision_agent-0.2.11.dist-info/RECORD,,
+ vision_agent-0.2.13.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ vision_agent-0.2.13.dist-info/METADATA,sha256=de1cx4IOvv_PK0XgFVdT0xzxejKweMug5xKAp3JwD24,9073
+ vision_agent-0.2.13.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
+ vision_agent-0.2.13.dist-info/RECORD,,