vision-agent 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/agent/__init__.py +2 -2
- vision_agent/agent/agent.py +1 -1
- vision_agent/agent/agent_coder.py +16 -10
- vision_agent/agent/{vision_agent_v2.py → data_interpreter.py} +12 -12
- vision_agent/agent/{vision_agent_v2_prompts.py → data_interpreter_prompts.py} +3 -3
- vision_agent/agent/easytool.py +8 -8
- vision_agent/agent/easytool_v2.py +778 -0
- vision_agent/agent/easytool_v2_prompts.py +152 -0
- vision_agent/agent/reflexion.py +8 -8
- vision_agent/agent/vision_agent.py +368 -690
- vision_agent/agent/vision_agent_prompts.py +233 -149
- vision_agent/llm/llm.py +3 -4
- vision_agent/lmm/lmm.py +6 -6
- vision_agent/tools/__init__.py +21 -22
- vision_agent/tools/easytool_tools.py +1242 -0
- vision_agent/tools/tools.py +533 -1090
- vision_agent-0.2.32.dist-info/METADATA +175 -0
- vision_agent-0.2.32.dist-info/RECORD +36 -0
- vision_agent/agent/vision_agent_v3.py +0 -394
- vision_agent/agent/vision_agent_v3_prompts.py +0 -234
- vision_agent/tools/tools_v2.py +0 -685
- vision_agent-0.2.30.dist-info/METADATA +0 -226
- vision_agent-0.2.30.dist-info/RECORD +0 -36
- {vision_agent-0.2.30.dist-info → vision_agent-0.2.32.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.30.dist-info → vision_agent-0.2.32.dist-info}/WHEEL +0 -0
| @@ -1,778 +1,456 @@ | |
| 1 | 
            +
            import copy
         | 
| 1 2 | 
             
            import json
         | 
| 2 3 | 
             
            import logging
         | 
| 3 4 | 
             
            import sys
         | 
| 4 | 
            -
            import tempfile
         | 
| 5 5 | 
             
            from pathlib import Path
         | 
| 6 | 
            -
            from typing import Any, Callable, Dict, List, Optional,  | 
| 6 | 
            +
            from typing import Any, Callable, Dict, List, Optional, Union, cast
         | 
| 7 7 |  | 
| 8 | 
            -
            from  | 
| 8 | 
            +
            from rich.console import Console
         | 
| 9 | 
            +
            from rich.syntax import Syntax
         | 
| 9 10 | 
             
            from tabulate import tabulate
         | 
| 10 11 |  | 
| 11 | 
            -
            from vision_agent.agent | 
| 12 | 
            -
            from vision_agent.agent.easytool_prompts import (
         | 
| 13 | 
            -
                ANSWER_GENERATE,
         | 
| 14 | 
            -
                ANSWER_SUMMARIZE,
         | 
| 15 | 
            -
                CHOOSE_PARAMETER,
         | 
| 16 | 
            -
                CHOOSE_TOOL,
         | 
| 17 | 
            -
                TASK_DECOMPOSE,
         | 
| 18 | 
            -
                TASK_TOPOLOGY,
         | 
| 19 | 
            -
            )
         | 
| 12 | 
            +
            from vision_agent.agent import Agent
         | 
| 20 13 | 
             
            from vision_agent.agent.vision_agent_prompts import (
         | 
| 21 | 
            -
                 | 
| 22 | 
            -
                 | 
| 23 | 
            -
                 | 
| 24 | 
            -
                 | 
| 25 | 
            -
                 | 
| 26 | 
            -
                 | 
| 14 | 
            +
                CODE,
         | 
| 15 | 
            +
                FEEDBACK,
         | 
| 16 | 
            +
                FIX_BUG,
         | 
| 17 | 
            +
                FULL_TASK,
         | 
| 18 | 
            +
                PLAN,
         | 
| 19 | 
            +
                REFLECT,
         | 
| 20 | 
            +
                SIMPLE_TEST,
         | 
| 21 | 
            +
                USER_REQ,
         | 
| 27 22 | 
             
            )
         | 
| 28 23 | 
             
            from vision_agent.llm import LLM, OpenAILLM
         | 
| 29 24 | 
             
            from vision_agent.lmm import LMM, OpenAILMM
         | 
| 30 | 
            -
            from vision_agent.tools import  | 
| 31 | 
            -
            from vision_agent.utils | 
| 32 | 
            -
             | 
| 33 | 
            -
                overlay_bboxes,
         | 
| 34 | 
            -
                overlay_heat_map,
         | 
| 35 | 
            -
                overlay_masks,
         | 
| 36 | 
            -
            )
         | 
| 25 | 
            +
            from vision_agent.tools import TOOL_DESCRIPTIONS, TOOLS_DF, UTILITIES_DOCSTRING
         | 
| 26 | 
            +
            from vision_agent.utils import Execute
         | 
| 27 | 
            +
            from vision_agent.utils.sim import Sim
         | 
| 37 28 |  | 
| 38 29 | 
             
            logging.basicConfig(stream=sys.stdout)
         | 
| 39 30 | 
             
            _LOGGER = logging.getLogger(__name__)
         | 
| 40 31 | 
             
            _MAX_TABULATE_COL_WIDTH = 80
         | 
| 32 | 
            +
            _EXECUTE = Execute(600)
         | 
| 33 | 
            +
            _CONSOLE = Console()
         | 
| 41 34 |  | 
| 42 35 |  | 
| 43 | 
            -
            def  | 
| 44 | 
            -
                 | 
| 45 | 
            -
                     | 
| 46 | 
            -
             | 
| 47 | 
            -
             | 
| 48 | 
            -
             | 
| 49 | 
            -
             | 
| 50 | 
            -
                    .strip()
         | 
| 51 | 
            -
                )
         | 
| 52 | 
            -
                return json.loads(s)
         | 
| 53 | 
            -
             | 
| 54 | 
            -
             | 
| 55 | 
            -
            def change_name(name: str) -> str:
                """Prefix reserved-looking identifiers with ``is_``.

                Names found in a fixed reserved set (``from``, ``class``, ``return``,
                ``false``, ``true``, ``id``, ``and``, ``ID`` and the empty string) are
                lower-cased and prefixed with ``is_``; every other name is returned
                unchanged.
                """
                reserved = {"from", "class", "return", "false", "true", "id", "and", "", "ID"}
                if name not in reserved:
                    return name
                return "is_" + name.lower()
         | 
| 60 | 
            -
             | 
| 61 | 
            -
             | 
| 62 | 
            -
            def format_tools(tools: Dict[int, Any]) -> str:
                """Render a tool table as one ``ID: <key> - <value>`` line per entry.

                The explicit ``ID:`` prefix keeps each tool's identifier unambiguous in
                the prompt. Every line is newline-terminated; an empty mapping yields "".
                """
                return "".join(f"ID: {tool_id} - {desc}\n" for tool_id, desc in tools.items())
         | 
| 68 | 
            -
             | 
| 69 | 
            -
             | 
| 70 | 
            -
            def format_tool_usage(tools: Dict[int, Any], tool_result: List[Dict]) -> str:
                """Describe the usage of every tool referenced in ``tool_result``.

                Args:
                    tools: mapping whose values are dicts with ``"name"`` and ``"usage"``.
                    tool_result: result dicts; entries without ``"tool_name"`` are skipped.

                Returns:
                    One newline-terminated ``"<tool_name> - <usage>"`` line per referenced
                    tool, in ``tool_result`` order (repeated names are repeated).

                Raises:
                    KeyError: if a ``tool_name`` matches no tool's ``"name"``.
                """
                usage_by_name = {tool["name"]: tool["usage"] for tool in tools.values()}
                lines = []
                for res in tool_result:
                    if "tool_name" not in res:
                        continue
                    name = res["tool_name"]
                    lines.append(f"{name} - {usage_by_name[name]}\n")
                return "".join(lines)
         | 
| 81 | 
            -
             | 
| 82 | 
            -
             | 
| 83 | 
            -
            def topological_sort(tasks: List[Dict]) -> List[Dict]:
                """Order tasks so each one comes after the tasks it depends on.

                In-degree/queue (Kahn-style) ordering.

                Args:
                    tasks: dicts with an ``"id"`` and a ``"dep"`` list of ids. Dependencies
                        on ids that are not present in ``tasks`` are ignored.

                Returns:
                    A new list containing every task. Tasks that can never be released
                    (part of a dependency cycle, or depending on one) are appended at the
                    end in their original relative order.
                """
                from collections import deque

                in_degree = {task["id"]: 0 for task in tasks}
                for task in tasks:
                    # Count each distinct dependency once: the release loop below
                    # decrements once per finished dependency, so counting duplicates
                    # (e.g. dep=[1, 1]) would leave the task stuck forever.
                    for dep in set(task["dep"]):
                        if dep in in_degree:
                            in_degree[task["id"]] += 1

                # deque gives O(1) pops from the front (list.pop(0) is O(n)).
                queue = deque(task for task in tasks if in_degree[task["id"]] == 0)
                sorted_order: List[Dict] = []

                while queue:
                    current = queue.popleft()
                    sorted_order.append(current)
                    for task in tasks:
                        if current["id"] in task["dep"]:
                            in_degree[task["id"]] -= 1
                            if in_degree[task["id"]] == 0:
                                queue.append(task)

                if len(sorted_order) != len(tasks):
                    completed_ids = {task["id"] for task in sorted_order}
                    sorted_order.extend(
                        task for task in tasks if task["id"] not in completed_ids
                    )
                return sorted_order
         | 
| 108 | 
            -
             | 
| 109 | 
            -
             | 
| 110 | 
            -
            def task_decompose(
                model: Union[LLM, LMM, Agent],
                question: str,
                tools: Dict[int, Any],
                reflections: str,
            ) -> Optional[Dict]:
                """Ask ``model`` to split ``question`` into sub-tasks.

                When ``reflections`` is non-empty the ``TASK_DECOMPOSE_DEPENDS`` prompt
                (which also receives the reflections) is used, otherwise
                ``TASK_DECOMPOSE``. The model is queried up to 12 times until its reply
                parses via ``parse_json`` and has a ``"Tasks"`` key, whose value is
                returned. After the last failed attempt the most recent raw reply is
                logged and ``None`` is returned.
                """
                if not reflections:
                    prompt = TASK_DECOMPOSE.format(question=question, tools=format_tools(tools))
                else:
                    prompt = TASK_DECOMPOSE_DEPENDS.format(
                        question=question, tools=format_tools(tools), reflections=reflections
                    )

                raw_reply = ""
                for _attempt in range(12):
                    try:
                        raw_reply = model(prompt)
                        return parse_json(raw_reply)["Tasks"]  # type: ignore
                    except Exception:
                        # Any failure (model error, malformed JSON, missing key) retries.
                        continue
                _LOGGER.error(f"Failed task_decompose on: {raw_reply}")
                return None
         | 
| 135 | 
            -
             | 
| 136 | 
            -
             | 
| 137 | 
            -
            def task_topology(
                model: Union[LLM, LMM, Agent], question: str, task_list: List[Dict]
            ) -> List[Dict[str, Any]]:
                """Ask ``model`` for the dependency structure of ``task_list``.

                The reply's ``"Tasks"`` entries have their ``"dep"`` field normalised
                to a list of ints: a comma-separated string is split and converted, a
                single int is wrapped in a list, and a list has each element converted
                with ``int``. Up to 12 attempts are made; if all fail, the last raw
                reply is logged and ``task_list`` is returned unchanged.
                """
                prompt = TASK_TOPOLOGY.format(question=question, task_list=task_list)
                raw_reply = ""
                for _attempt in range(12):
                    try:
                        raw_reply = model(prompt)
                        parsed_tasks = parse_json(raw_reply)["Tasks"]
                        for entry in parsed_tasks:
                            deps = entry["dep"]
                            if isinstance(deps, str):
                                entry["dep"] = [int(d) for d in deps.split(",")]
                            elif isinstance(deps, int):
                                entry["dep"] = [deps]
                            elif isinstance(deps, list):
                                entry["dep"] = [int(d) for d in deps]
                        return parsed_tasks  # type: ignore
                    except Exception:
                        # Malformed replies (bad JSON, non-numeric deps, ...) retry.
                        continue
                _LOGGER.error(f"Failed task_topology on: {raw_reply}")
                return task_list
         | 
| 161 | 
            -
             | 
| 162 | 
            -
             | 
| 163 | 
            -
            def choose_tool(
                model: Union[LLM, LMM, Agent],
                question: str,
                tools: Dict[int, Any],
                reflections: str,
            ) -> Optional[int]:
                """Ask ``model`` which tool ID fits ``question``.

                Uses ``CHOOSE_TOOL_DEPENDS`` (which also receives ``reflections``) when
                reflections are present, otherwise ``CHOOSE_TOOL``. Returns the ``"ID"``
                value from the parsed reply; after 12 failed attempts the last raw
                reply is logged and ``None`` is returned.
                """
                if not reflections:
                    prompt = CHOOSE_TOOL.format(question=question, tools=format_tools(tools))
                else:
                    prompt = CHOOSE_TOOL_DEPENDS.format(
                        question=question, tools=format_tools(tools), reflections=reflections
                    )

                raw_reply = ""
                for _attempt in range(12):
                    try:
                        raw_reply = model(prompt)
                        return parse_json(raw_reply)["ID"]  # type: ignore
                    except Exception:
                        continue
                _LOGGER.error(f"Failed choose_tool on: {raw_reply}")
                return None
         | 
| 188 | 
            -
             | 
| 189 | 
            -
             | 
| 190 | 
            -
            def choose_parameter(
                model: Union[LLM, LMM, Agent],
                question: str,
                tool_usage: Dict,
                previous_log: str,
                reflections: str,
            ) -> Optional[Any]:
                """Ask ``model`` for the parameters to call a tool with.

                Uses ``CHOOSE_PARAMETER_DEPENDS`` (which also receives ``reflections``)
                when reflections are present, otherwise ``CHOOSE_PARAMETER``. Returns
                the ``"Parameters"`` value of the parsed reply; after 12 failed
                attempts the last raw reply is logged and ``None`` is returned.
                """
                # TODO: should format tool_usage
                prompt_args = dict(
                    question=question, tool_usage=tool_usage, previous_log=previous_log
                )
                if reflections:
                    prompt = CHOOSE_PARAMETER_DEPENDS.format(**prompt_args, reflections=reflections)
                else:
                    prompt = CHOOSE_PARAMETER.format(**prompt_args)

                raw_reply = ""
                for _attempt in range(12):
                    try:
                        raw_reply = model(prompt)
                        return parse_json(raw_reply)["Parameters"]
                    except Exception:
                        continue
                _LOGGER.error(f"Failed choose_parameter on: {raw_reply}")
                return None
         | 
| 222 | 
            -
             | 
| 223 | 
            -
             | 
| 224 | 
            -
            def answer_generate(
         | 
| 225 | 
            -
                model: Union[LLM, LMM, Agent],
         | 
| 226 | 
            -
                question: str,
         | 
| 227 | 
            -
                call_results: str,
         | 
| 228 | 
            -
                previous_log: str,
         | 
| 229 | 
            -
                reflections: str,
         | 
| 230 | 
            -
            ) -> str:
         | 
| 231 | 
            -
                if reflections:
         | 
| 232 | 
            -
                    prompt = ANSWER_GENERATE_DEPENDS.format(
         | 
| 233 | 
            -
                        question=question,
         | 
| 234 | 
            -
                        call_results=call_results,
         | 
| 235 | 
            -
                        previous_log=previous_log,
         | 
| 236 | 
            -
                        reflections=reflections,
         | 
| 237 | 
            -
                    )
         | 
| 238 | 
            -
                else:
         | 
| 239 | 
            -
                    prompt = ANSWER_GENERATE.format(
         | 
| 240 | 
            -
                        question=question, call_results=call_results, previous_log=previous_log
         | 
| 36 | 
            +
            def format_memory(memory: List[Dict[str, str]]) -> str:
                """Render the code/feedback history into the ``FEEDBACK`` prompt.

                Each entry of ``memory`` needs ``"code"`` and ``"feedback"`` keys. Entries
                are numbered from 0, the code is wrapped in a python fence, and the
                rendered entries are joined with newlines.
                """
                entries = []
                for idx, item in enumerate(memory):
                    entries.append(
                        f"### Feedback {idx}:\nCode: ```python\n{item['code']}\n```\nFeedback: {item['feedback']}\n"
                    )
                return FEEDBACK.format(feedback="\n".join(entries))
         | 
| 243 45 |  | 
| 244 46 |  | 
| 245 | 
            -
            def  | 
| 246 | 
            -
                 | 
| 247 | 
            -
             | 
| 248 | 
            -
                 | 
| 249 | 
            -
                     | 
| 250 | 
            -
                        question=question, answers=answers, reflections=reflections
         | 
| 251 | 
            -
                    )
         | 
| 47 | 
            +
            def extract_code(code: str) -> str:
                """Return the contents of the first python-fenced block in ``code``.

                A fence that starts a line ("\\n```python") is preferred over one found
                anywhere else. If ``code`` has no python fence it is returned unchanged.
                The text after the opening fence is cut at the next triple backtick.
                When that closing fence is missing (e.g. a truncated model reply),
                everything after the opening fence is kept. The newline that follows
                the opening fence is preserved.
                """
                if "\n```python" in code:
                    start = "\n```python"
                elif "```python" in code:
                    start = "```python"
                else:
                    return code

                code = code[code.find(start) + len(start) :]
                end = code.find("```")
                # str.find returns -1 for an unterminated fence; slicing with -1 would
                # silently drop the last character of the code.
                if end != -1:
                    code = code[:end]
                if code.startswith("python\n"):
                    code = code[len("python\n") :]
                return code
         | 
| 255 60 |  | 
| 256 61 |  | 
| 257 | 
            -
            def  | 
| 62 | 
            +
            def extract_json(json_str: str) -> Dict[str, Any]:
                """Parse a JSON object from a model reply, tolerating markdown fences.

                The raw string is tried first. On failure:

                - a "```json" fence is stripped up to the next triple backtick;
                - otherwise a bare triple-backtick fence is stripped up to the first
                  closing brace immediately followed by a fence (the brace is kept),
                  or, failing that, up to the last fence.

                An unclosed fence keeps everything after the opener.

                Raises:
                    json.JSONDecodeError: if no candidate parses.
                """
                try:
                    json_dict = json.loads(json_str)
                except json.JSONDecodeError:
                    if "```json" in json_str:
                        json_str = json_str[json_str.find("```json") + len("```json") :]
                        end = json_str.find("```")
                        if end != -1:
                            json_str = json_str[:end]
                    elif "```" in json_str:
                        json_str = json_str[json_str.find("```") + len("```") :]
                        # Prefer the fence right after the closing brace so a fence inside a
                        # string value does not truncate the object; keep the brace itself.
                        end = json_str.find("}```")
                        if end != -1:
                            json_str = json_str[: end + 1]
                        else:
                            end = json_str.rfind("```")
                            if end != -1:
                                json_str = json_str[:end]
                    json_dict = json.loads(json_str)
                return json_dict  # type: ignore
         | 
| 75 | 
            +
             | 
| 76 | 
            +
             | 
| 77 | 
            +
            def write_plan(
                chat: List[Dict[str, str]],
                tool_desc: str,
                working_memory: str,
                model: Union[LLM, LMM],
                media: Optional[List[Union[str, Path]]] = None,
            ) -> List[Dict[str, str]]:
                """Ask ``model`` for a plan addressing the user's latest message.

                ``chat`` is deep-copied, so the caller's list is not modified. The last
                message's content is replaced with the ``PLAN`` prompt, built from that
                message, ``tool_desc`` and ``working_memory``. The ``"plan"`` entry of
                the model's JSON reply is returned. ``media`` is forwarded as ``images``
                only when ``model`` is an ``OpenAILMM``.

                Raises:
                    ValueError: if the last chat message is not from the user.
                """
                conversation = copy.deepcopy(chat)
                last_message = conversation[-1]
                if last_message["role"] != "user":
                    raise ValueError("Last chat message must be from the user.")

                last_message["content"] = PLAN.format(
                    context=USER_REQ.format(user_request=last_message["content"]),
                    tool_desc=tool_desc,
                    feedback=working_memory,
                )
                if isinstance(model, OpenAILMM):
                    reply = model.chat(conversation, images=media)
                else:
                    reply = model.chat(conversation)
                return extract_json(reply)["plan"]  # type: ignore
         | 
| 96 | 
            +
             | 
| 97 | 
            +
             | 
| 98 | 
            +
            def reflect(
                chat: List[Dict[str, str]],
                plan: str,
                code: str,
                model: LLM,
            ) -> Dict[str, Union[str, bool]]:
                """Ask ``model`` to critique ``plan`` and ``code`` against the user request.

                ``chat`` is deep-copied, so the caller's list is not modified. The last
                message is rewritten into the ``REFLECT`` prompt, and the model's reply
                is returned as parsed by ``extract_json``.

                Raises:
                    ValueError: if the last chat message is not from the user.
                """
                conversation = copy.deepcopy(chat)
                last_message = conversation[-1]
                if last_message["role"] != "user":
                    raise ValueError("Last chat message must be from the user.")

                last_message["content"] = REFLECT.format(
                    context=USER_REQ.format(user_request=last_message["content"]),
                    plan=plan,
                    code=code,
                )
                return extract_json(model.chat(conversation))
         | 
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
            +
            def write_and_test_code(
         | 
| 116 | 
            +
                task: str,
         | 
| 117 | 
            +
                tool_info: str,
         | 
| 118 | 
            +
                tool_utils: str,
         | 
| 119 | 
            +
                working_memory: str,
         | 
| 120 | 
            +
                coder: LLM,
         | 
| 121 | 
            +
                tester: LLM,
         | 
| 122 | 
            +
                debugger: LLM,
         | 
| 123 | 
            +
                log_progress: Callable[[Dict[str, Any]], None],
         | 
| 124 | 
            +
                verbosity: int = 0,
         | 
| 125 | 
            +
                max_retries: int = 3,
         | 
| 126 | 
            +
                input_media: Optional[Union[str, Path]] = None,
         | 
| 127 | 
            +
            ) -> Dict[str, Any]:
         | 
| 128 | 
            +
                code = extract_code(
         | 
| 129 | 
            +
                    coder(CODE.format(docstring=tool_info, question=task, feedback=working_memory))
         | 
| 130 | 
            +
                )
         | 
| 131 | 
            +
                test = extract_code(
         | 
| 132 | 
            +
                    tester(
         | 
| 133 | 
            +
                        SIMPLE_TEST.format(
         | 
| 134 | 
            +
                            docstring=tool_utils,
         | 
| 135 | 
            +
                            question=task,
         | 
| 136 | 
            +
                            code=code,
         | 
| 137 | 
            +
                            feedback=working_memory,
         | 
| 138 | 
            +
                            media=input_media,
         | 
| 139 | 
            +
                        )
         | 
| 140 | 
            +
                    )
         | 
| 280 141 | 
             
                )
         | 
| 281 | 
            -
                if (
         | 
| 282 | 
            -
                    issubclass(type(reflect_model), LMM)
         | 
| 283 | 
            -
                    and images is not None
         | 
| 284 | 
            -
                    and all([Path(image).suffix in [".jpg", ".jpeg", ".png"] for image in images])
         | 
| 285 | 
            -
                ):
         | 
| 286 | 
            -
                    return reflect_model(prompt, images=images)  # type: ignore
         | 
| 287 | 
            -
                return reflect_model(prompt)
         | 
| 288 142 |  | 
| 143 | 
            +
                success, result = _EXECUTE.run_isolation(f"{code}\n{test}")
         | 
| 144 | 
            +
                if verbosity == 2:
         | 
| 145 | 
            +
                    _LOGGER.info("Initial code and tests:")
         | 
| 146 | 
            +
                    log_progress(
         | 
| 147 | 
            +
                        {
         | 
| 148 | 
            +
                            "log": "Code:",
         | 
| 149 | 
            +
                            "code": code,
         | 
| 150 | 
            +
                        }
         | 
| 151 | 
            +
                    )
         | 
| 152 | 
            +
                    log_progress(
         | 
| 153 | 
            +
                        {
         | 
| 154 | 
            +
                            "log": "Test:",
         | 
| 155 | 
            +
                            "code": test,
         | 
| 156 | 
            +
                        }
         | 
| 157 | 
            +
                    )
         | 
| 158 | 
            +
                    _CONSOLE.print(
         | 
| 159 | 
            +
                        Syntax(f"{code}\n{test}", "python", theme="gruvbox-dark", line_numbers=True)
         | 
| 160 | 
            +
                    )
         | 
| 161 | 
            +
                    log_progress(
         | 
| 162 | 
            +
                        {
         | 
| 163 | 
            +
                            "log": "Result:",
         | 
| 164 | 
            +
                            "result": result,
         | 
| 165 | 
            +
                        }
         | 
| 166 | 
            +
                    )
         | 
| 167 | 
            +
                    _LOGGER.info(f"Initial result: {result}")
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                count = 0
         | 
| 170 | 
            +
                new_working_memory = []
         | 
| 171 | 
            +
                while not success and count < max_retries:
         | 
| 172 | 
            +
                    fixed_code_and_test = extract_json(
         | 
| 173 | 
            +
                        debugger(
         | 
| 174 | 
            +
                            FIX_BUG.format(
         | 
| 175 | 
            +
                                code=code, tests=test, result=result, feedback=working_memory
         | 
| 176 | 
            +
                            )
         | 
| 177 | 
            +
                        )
         | 
| 178 | 
            +
                    )
         | 
| 179 | 
            +
                    if fixed_code_and_test["code"].strip() != "":
         | 
| 180 | 
            +
                        code = extract_code(fixed_code_and_test["code"])
         | 
| 181 | 
            +
                    if fixed_code_and_test["test"].strip() != "":
         | 
| 182 | 
            +
                        test = extract_code(fixed_code_and_test["test"])
         | 
| 183 | 
            +
                    new_working_memory.append(
         | 
| 184 | 
            +
                        {"code": f"{code}\n{test}", "feedback": fixed_code_and_test["reflections"]}
         | 
| 185 | 
            +
                    )
         | 
| 289 186 |  | 
| 290 | 
            -
             | 
| 291 | 
            -
             | 
| 292 | 
            -
             | 
| 293 | 
            -
             | 
| 294 | 
            -
             | 
| 295 | 
            -
             | 
| 296 | 
            -
                # LMMs have a hard time following directions, so make the criteria less strict
         | 
| 297 | 
            -
                finish = (
         | 
| 298 | 
            -
                    "finish" in reflect.lower() and len(reflect) < 100
         | 
| 299 | 
            -
                ) or "finish" in reflect.lower()[-10:]
         | 
| 300 | 
            -
                return {"Finish": finish, "Reflection": reflect}
         | 
| 301 | 
            -
             | 
| 302 | 
            -
             | 
| 303 | 
            -
            def _handle_extract_frames(
         | 
| 304 | 
            -
                image_to_data: Dict[str, Dict], tool_result: Dict
         | 
| 305 | 
            -
            ) -> Dict[str, Dict]:
         | 
| 306 | 
            -
                image_to_data = image_to_data.copy()
         | 
| 307 | 
            -
                # handle extract_frames_ case, useful if it extracts frames but doesn't do
         | 
| 308 | 
            -
                # any following processing
         | 
| 309 | 
            -
                for video_file_output in tool_result["call_results"]:
         | 
| 310 | 
            -
                    # When the video tool is run with wrong parameters, exit the loop
         | 
| 311 | 
            -
                    if not isinstance(video_file_output, tuple) or len(video_file_output) < 2:
         | 
| 312 | 
            -
                        break
         | 
| 313 | 
            -
                    for frame, _ in video_file_output:
         | 
| 314 | 
            -
                        image = frame
         | 
| 315 | 
            -
                        if image not in image_to_data:
         | 
| 316 | 
            -
                            image_to_data[image] = {
         | 
| 317 | 
            -
                                "bboxes": [],
         | 
| 318 | 
            -
                                "masks": [],
         | 
| 319 | 
            -
                                "heat_map": [],
         | 
| 320 | 
            -
                                "labels": [],
         | 
| 321 | 
            -
                                "scores": [],
         | 
| 187 | 
            +
                    success, result = _EXECUTE.run_isolation(f"{code}\n{test}")
         | 
| 188 | 
            +
                    if verbosity == 2:
         | 
| 189 | 
            +
                        log_progress(
         | 
| 190 | 
            +
                            {
         | 
| 191 | 
            +
                                "log": f"Debug attempt {count + 1}, reflection:",
         | 
| 192 | 
            +
                                "result": fixed_code_and_test["reflections"],
         | 
| 322 193 | 
             
                            }
         | 
| 323 | 
            -
             | 
| 324 | 
            -
             | 
| 325 | 
            -
             | 
| 326 | 
            -
             | 
| 327 | 
            -
             | 
| 328 | 
            -
             | 
| 329 | 
            -
             | 
| 330 | 
            -
             | 
| 331 | 
            -
             | 
| 332 | 
            -
             | 
| 333 | 
            -
             | 
| 334 | 
            -
             | 
| 335 | 
            -
             | 
| 336 | 
            -
             | 
| 337 | 
            -
                         | 
| 338 | 
            -
             | 
| 339 | 
            -
             | 
| 340 | 
            -
                    if len(tool_result["parameters"]) < 1 or (
         | 
| 341 | 
            -
                        "image" not in tool_result["parameters"][0]
         | 
| 342 | 
            -
                    ):
         | 
| 343 | 
            -
                        return image_to_data
         | 
| 344 | 
            -
             | 
| 345 | 
            -
                for param, call_result in zip(parameters, tool_result["call_results"]):
         | 
| 346 | 
            -
                    # Calls can fail, so we need to check if the call was successful. It can either:
         | 
| 347 | 
            -
                    # 1. return a str or some error that's not a dictionary
         | 
| 348 | 
            -
                    # 2. return a dictionary but not have the necessary keys
         | 
| 349 | 
            -
             | 
| 350 | 
            -
                    if not isinstance(call_result, dict) or (
         | 
| 351 | 
            -
                        "bboxes" not in call_result
         | 
| 352 | 
            -
                        and "mask" not in call_result
         | 
| 353 | 
            -
                        and "heat_map" not in call_result
         | 
| 354 | 
            -
                    ):
         | 
| 355 | 
            -
                        return image_to_data
         | 
| 356 | 
            -
             | 
| 357 | 
            -
                    # if the call was successful, then we can add the image data
         | 
| 358 | 
            -
                    image = param["image"]
         | 
| 359 | 
            -
                    if image not in image_to_data:
         | 
| 360 | 
            -
                        image_to_data[image] = {
         | 
| 361 | 
            -
                            "bboxes": [],
         | 
| 362 | 
            -
                            "masks": [],
         | 
| 363 | 
            -
                            "heat_map": [],
         | 
| 364 | 
            -
                            "labels": [],
         | 
| 365 | 
            -
                            "scores": [],
         | 
| 366 | 
            -
                        }
         | 
| 194 | 
            +
                        )
         | 
| 195 | 
            +
                        _LOGGER.info(
         | 
| 196 | 
            +
                            f"Debug attempt {count + 1}, reflection: {fixed_code_and_test['reflections']}"
         | 
| 197 | 
            +
                        )
         | 
| 198 | 
            +
                        _CONSOLE.print(
         | 
| 199 | 
            +
                            Syntax(
         | 
| 200 | 
            +
                                f"{code}\n{test}", "python", theme="gruvbox-dark", line_numbers=True
         | 
| 201 | 
            +
                            )
         | 
| 202 | 
            +
                        )
         | 
| 203 | 
            +
                        log_progress(
         | 
| 204 | 
            +
                            {
         | 
| 205 | 
            +
                                "log": "Debug result:",
         | 
| 206 | 
            +
                                "result": result,
         | 
| 207 | 
            +
                            }
         | 
| 208 | 
            +
                        )
         | 
| 209 | 
            +
                        _LOGGER.info(f"Debug result: {result}")
         | 
| 210 | 
            +
                    count += 1
         | 
| 367 211 |  | 
| 368 | 
            -
             | 
| 369 | 
            -
                     | 
| 370 | 
            -
                     | 
| 371 | 
            -
             | 
| 372 | 
            -
                     | 
| 373 | 
            -
                     | 
| 374 | 
            -
             | 
| 375 | 
            -
             | 
| 376 | 
            -
             | 
| 377 | 
            -
             | 
| 378 | 
            -
             | 
| 379 | 
            -
             | 
| 380 | 
            -
             | 
| 381 | 
            -
             | 
| 382 | 
            -
             | 
| 383 | 
            -
             | 
| 384 | 
            -
             | 
| 385 | 
            -
             | 
| 386 | 
            -
                 | 
| 387 | 
            -
             | 
| 388 | 
            -
                 | 
| 389 | 
            -
             | 
| 390 | 
            -
             | 
| 391 | 
            -
                 | 
| 392 | 
            -
                 | 
| 393 | 
            -
             | 
| 394 | 
            -
             | 
| 395 | 
            -
             | 
| 396 | 
            -
                 | 
| 397 | 
            -
             | 
| 398 | 
            -
             | 
| 399 | 
            -
             | 
| 400 | 
            -
             | 
| 401 | 
            -
                         | 
| 402 | 
            -
             | 
| 403 | 
            -
             | 
| 404 | 
            -
             | 
| 405 | 
            -
             | 
| 406 | 
            -
                        "ocr_",
         | 
| 407 | 
            -
                    ]:
         | 
| 408 | 
            -
                        continue
         | 
| 409 | 
            -
             | 
| 410 | 
            -
                    if tool_result["tool_name"] == "extract_frames_":
         | 
| 411 | 
            -
                        image_to_data = _handle_extract_frames(image_to_data, tool_result)
         | 
| 412 | 
            -
                    else:
         | 
| 413 | 
            -
                        image_to_data = _handle_viz_tools(image_to_data, tool_result)
         | 
| 414 | 
            -
             | 
| 415 | 
            -
                visualized_images = []
         | 
| 416 | 
            -
                for image_str in image_to_data:
         | 
| 417 | 
            -
                    image_path = Path(image_str)
         | 
| 418 | 
            -
                    image_data = image_to_data[image_str]
         | 
| 419 | 
            -
                    if "_counting_" in tool_result["tool_name"]:
         | 
| 420 | 
            -
                        image = overlay_heat_map(image_path, image_data)
         | 
| 421 | 
            -
                    else:
         | 
| 422 | 
            -
                        image = overlay_masks(image_path, image_data)
         | 
| 423 | 
            -
                        image = overlay_bboxes(image, image_data)
         | 
| 424 | 
            -
                    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
         | 
| 425 | 
            -
                        image.save(f.name)
         | 
| 426 | 
            -
                        visualized_images.append(f.name)
         | 
| 427 | 
            -
                return visualized_images
         | 
| 212 | 
            +
                if verbosity >= 1:
         | 
| 213 | 
            +
                    _LOGGER.info("Final code and tests:")
         | 
| 214 | 
            +
                    _CONSOLE.print(
         | 
| 215 | 
            +
                        Syntax(f"{code}\n{test}", "python", theme="gruvbox-dark", line_numbers=True)
         | 
| 216 | 
            +
                    )
         | 
| 217 | 
            +
                    _LOGGER.info(f"Final Result: {result}")
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                return {
         | 
| 220 | 
            +
                    "code": code,
         | 
| 221 | 
            +
                    "test": test,
         | 
| 222 | 
            +
                    "success": success,
         | 
| 223 | 
            +
                    "test_result": result,
         | 
| 224 | 
            +
                    "working_memory": new_working_memory,
         | 
| 225 | 
            +
                }
         | 
| 226 | 
            +
             | 
| 227 | 
            +
             | 
| 228 | 
            +
            def retrieve_tools(
         | 
| 229 | 
            +
                plan: List[Dict[str, str]],
         | 
| 230 | 
            +
                tool_recommender: Sim,
         | 
| 231 | 
            +
                log_progress: Callable[[Dict[str, Any]], None],
         | 
| 232 | 
            +
                verbosity: int = 0,
         | 
| 233 | 
            +
            ) -> str:
         | 
| 234 | 
            +
                tool_info = []
         | 
| 235 | 
            +
                tool_desc = []
         | 
| 236 | 
            +
                for task in plan:
         | 
| 237 | 
            +
                    tools = tool_recommender.top_k(task["instructions"], k=2, thresh=0.3)
         | 
| 238 | 
            +
                    tool_info.extend([e["doc"] for e in tools])
         | 
| 239 | 
            +
                    tool_desc.extend([e["desc"] for e in tools])
         | 
| 240 | 
            +
                if verbosity == 2:
         | 
| 241 | 
            +
                    log_progress(
         | 
| 242 | 
            +
                        {
         | 
| 243 | 
            +
                            "log": "Retrieved tools:",
         | 
| 244 | 
            +
                            "tools": tool_desc,
         | 
| 245 | 
            +
                        }
         | 
| 246 | 
            +
                    )
         | 
| 247 | 
            +
                    _LOGGER.info(f"Tools: {tool_desc}")
         | 
| 248 | 
            +
                tool_info_set = set(tool_info)
         | 
| 249 | 
            +
                return "\n\n".join(tool_info_set)
         | 
| 428 250 |  | 
| 429 251 |  | 
| 430 252 | 
             
            class VisionAgent(Agent):
         | 
| 431 | 
            -
                 | 
| 432 | 
            -
                 | 
| 433 | 
            -
                 | 
| 434 | 
            -
                https://arxiv.org/abs/ | 
| 435 | 
            -
                 | 
| 436 | 
            -
                and final results, if not it will redo the task with this newly added reflection.
         | 
| 253 | 
            +
                """Vision Agent is an agentic framework that can output code based on a user
         | 
| 254 | 
            +
                request. It can plan tasks, retrieve relevant tools, write code, write tests and
         | 
| 255 | 
            +
                reflect on failed test cases to debug code. It is inspired by AgentCoder
         | 
| 256 | 
            +
                https://arxiv.org/abs/2312.13010 and Data Interpeter
         | 
| 257 | 
            +
                https://arxiv.org/abs/2402.18679
         | 
| 437 258 |  | 
| 438 259 | 
             
                Example
         | 
| 439 260 | 
             
                -------
         | 
| 440 | 
            -
                    >>> from vision_agent | 
| 261 | 
            +
                    >>> from vision_agent import VisionAgent
         | 
| 441 262 | 
             
                    >>> agent = VisionAgent()
         | 
| 442 | 
            -
                    >>>  | 
| 443 | 
            -
                    >>> print(resp)
         | 
| 444 | 
            -
                    "The total cost is $57.50."
         | 
| 263 | 
            +
                    >>> code = agent("What percentage of the area of the jar is filled with coffee beans?", media="jar.jpg")
         | 
| 445 264 | 
             
                """
         | 
| 446 265 |  | 
| 447 266 | 
             
                def __init__(
         | 
| 448 267 | 
             
                    self,
         | 
| 449 | 
            -
                     | 
| 450 | 
            -
                     | 
| 451 | 
            -
                     | 
| 452 | 
            -
                     | 
| 453 | 
            -
                     | 
| 268 | 
            +
                    planner: Optional[LLM] = None,
         | 
| 269 | 
            +
                    coder: Optional[LLM] = None,
         | 
| 270 | 
            +
                    tester: Optional[LLM] = None,
         | 
| 271 | 
            +
                    debugger: Optional[LLM] = None,
         | 
| 272 | 
            +
                    tool_recommender: Optional[Sim] = None,
         | 
| 273 | 
            +
                    verbosity: int = 0,
         | 
| 454 274 | 
             
                    report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
         | 
| 455 | 
            -
                ):
         | 
| 456 | 
            -
                    """ | 
| 275 | 
            +
                ) -> None:
         | 
| 276 | 
            +
                    """Initialize the Vision Agent.
         | 
| 457 277 |  | 
| 458 278 | 
             
                    Parameters:
         | 
| 459 | 
            -
                         | 
| 460 | 
            -
                         | 
| 461 | 
            -
                         | 
| 462 | 
            -
                         | 
| 463 | 
            -
                         | 
| 464 | 
            -
                         | 
| 279 | 
            +
                        planner (Optional[LLM]): The planner model to use. Defaults to OpenAILLM.
         | 
| 280 | 
            +
                        coder (Optional[LLM]): The coder model to use. Defaults to OpenAILLM.
         | 
| 281 | 
            +
                        tester (Optional[LLM]): The tester model to use. Defaults to OpenAILLM.
         | 
| 282 | 
            +
                        debugger (Optional[LLM]): The debugger model to
         | 
| 283 | 
            +
                        tool_recommender (Optional[Sim]): The tool recommender model to use.
         | 
| 284 | 
            +
                        verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
         | 
| 285 | 
            +
                            highest verbosity level which will output all intermediate debugging
         | 
| 286 | 
            +
                            code.
         | 
| 287 | 
            +
                        report_progress_callback: a callback to report the progress of the agent.
         | 
| 288 | 
            +
                            This is useful for streaming logs in a web application where multiple
         | 
| 289 | 
            +
                            VisionAgent instances are running in parallel. This callback ensures
         | 
| 290 | 
            +
                            that the progress are not mixed up.
         | 
| 465 291 | 
             
                    """
         | 
| 466 | 
            -
             | 
| 467 | 
            -
             | 
| 468 | 
            -
                        if  | 
| 469 | 
            -
                        else task_model
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                    self.planner = (
         | 
| 294 | 
            +
                        OpenAILLM(temperature=0.0, json_mode=True) if planner is None else planner
         | 
| 470 295 | 
             
                    )
         | 
| 471 | 
            -
                    self. | 
| 472 | 
            -
             | 
| 473 | 
            -
             | 
| 474 | 
            -
                        else  | 
| 296 | 
            +
                    self.coder = OpenAILLM(temperature=0.0) if coder is None else coder
         | 
| 297 | 
            +
                    self.tester = OpenAILLM(temperature=0.0) if tester is None else tester
         | 
| 298 | 
            +
                    self.debugger = (
         | 
| 299 | 
            +
                        OpenAILLM(temperature=0.0, json_mode=True) if debugger is None else debugger
         | 
| 475 300 | 
             
                    )
         | 
| 476 | 
            -
             | 
| 477 | 
            -
             | 
| 478 | 
            -
                         | 
| 479 | 
            -
                         | 
| 301 | 
            +
             | 
| 302 | 
            +
                    self.tool_recommender = (
         | 
| 303 | 
            +
                        Sim(TOOLS_DF, sim_key="desc")
         | 
| 304 | 
            +
                        if tool_recommender is None
         | 
| 305 | 
            +
                        else tool_recommender
         | 
| 480 306 | 
             
                    )
         | 
| 481 | 
            -
                    self. | 
| 482 | 
            -
                    self. | 
| 307 | 
            +
                    self.verbosity = verbosity
         | 
| 308 | 
            +
                    self.max_retries = 2
         | 
| 483 309 | 
             
                    self.report_progress_callback = report_progress_callback
         | 
| 484 | 
            -
                    if verbose:
         | 
| 485 | 
            -
                        _LOGGER.setLevel(logging.INFO)
         | 
| 486 310 |  | 
| 487 311 | 
             
                def __call__(
         | 
| 488 312 | 
             
                    self,
         | 
| 489 313 | 
             
                    input: Union[List[Dict[str, str]], str],
         | 
| 490 | 
            -
                     | 
| 491 | 
            -
                    reference_data: Optional[Dict[str, str]] = None,
         | 
| 492 | 
            -
                    visualize_output: Optional[bool] = False,
         | 
| 493 | 
            -
                    self_reflection: Optional[bool] = True,
         | 
| 314 | 
            +
                    media: Optional[Union[str, Path]] = None,
         | 
| 494 315 | 
             
                ) -> str:
         | 
| 495 | 
            -
                    """ | 
| 316 | 
            +
                    """Chat with Vision Agent and return intermediate information regarding the task.
         | 
| 496 317 |  | 
| 497 318 | 
             
                    Parameters:
         | 
| 498 | 
            -
                        chat: A conversation in the format of
         | 
| 319 | 
            +
                        chat (List[Dict[str, str]]): A conversation in the format of
         | 
| 499 320 | 
             
                            [{"role": "user", "content": "describe your task here..."}].
         | 
| 500 | 
            -
                         | 
| 501 | 
            -
                         | 
| 502 | 
            -
                            box in the format of:
         | 
| 503 | 
            -
                            {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
         | 
| 504 | 
            -
                            where the bounding box coordinates are normalized.
         | 
| 505 | 
            -
                        visualize_output: Whether to visualize the output.
         | 
| 506 | 
            -
                        self_reflection: boolean to enable and disable self reflection.
         | 
| 321 | 
            +
                        media (Optional[Union[str, Path]]): The media file to be used in the task.
         | 
| 322 | 
            +
                        self_reflection (bool): Whether to reflect on the task and debug the code.
         | 
| 507 323 |  | 
| 508 324 | 
             
                    Returns:
         | 
| 509 | 
            -
                        The  | 
| 325 | 
            +
                        str: The code output by the Vision Agent.
         | 
| 510 326 | 
             
                    """
         | 
| 327 | 
            +
             | 
| 511 328 | 
             
                    if isinstance(input, str):
         | 
| 512 329 | 
             
                        input = [{"role": "user", "content": input}]
         | 
| 513 | 
            -
                     | 
| 514 | 
            -
             | 
| 515 | 
            -
             | 
| 516 | 
            -
                        visualize_output=visualize_output,
         | 
| 517 | 
            -
                        reference_data=reference_data,
         | 
| 518 | 
            -
                        self_reflection=self_reflection,
         | 
| 519 | 
            -
                    )
         | 
| 520 | 
            -
             | 
| 521 | 
            -
                def log_progress(self, data: Dict[str, Any]) -> None:
         | 
| 522 | 
            -
                    _LOGGER.info(data)
         | 
| 523 | 
            -
                    if self.report_progress_callback:
         | 
| 524 | 
            -
                        self.report_progress_callback(data)
         | 
| 525 | 
            -
             | 
| 526 | 
            -
                def _report_visualization_via_callback(
         | 
| 527 | 
            -
                    self, images: Sequence[Union[str, Path]]
         | 
| 528 | 
            -
                ) -> None:
         | 
| 529 | 
            -
                    """This is intended for streaming the visualization images via the callback to the client side."""
         | 
| 530 | 
            -
                    if self.report_progress_callback:
         | 
| 531 | 
            -
                        self.report_progress_callback({"log": "<VIZ>"})
         | 
| 532 | 
            -
                        if images:
         | 
| 533 | 
            -
                            for img in images:
         | 
| 534 | 
            -
                                self.report_progress_callback(
         | 
| 535 | 
            -
                                    {"log": f"<IMG>base:64{convert_to_b64(img)}</IMG>"}
         | 
| 536 | 
            -
                                )
         | 
| 537 | 
            -
                        self.report_progress_callback({"log": "</VIZ>"})
         | 
| 330 | 
            +
                    results = self.chat_with_workflow(input, media)
         | 
| 331 | 
            +
                    results.pop("working_memory")
         | 
| 332 | 
            +
                    return results  # type: ignore
         | 
| 538 333 |  | 
| 539 334 | 
             
                def chat_with_workflow(
         | 
| 540 335 | 
             
                    self,
         | 
| 541 336 | 
             
                    chat: List[Dict[str, str]],
         | 
| 542 | 
            -
                     | 
| 543 | 
            -
                     | 
| 544 | 
            -
             | 
| 545 | 
            -
                     | 
| 546 | 
            -
                ) -> Tuple[str, List[Dict]]:
         | 
| 547 | 
            -
                    """Chat with the vision agent and return the final answer and all tool results.
         | 
| 337 | 
            +
                    media: Optional[Union[str, Path]] = None,
         | 
| 338 | 
            +
                    self_reflection: bool = False,
         | 
| 339 | 
            +
                ) -> Dict[str, Any]:
         | 
| 340 | 
            +
                    """Chat with Vision Agent and return intermediate information regarding the task.
         | 
| 548 341 |  | 
| 549 342 | 
             
                    Parameters:
         | 
| 550 | 
            -
                        chat: A conversation in the format of
         | 
| 343 | 
            +
                        chat (List[Dict[str, str]]): A conversation in the format of
         | 
| 551 344 | 
             
                            [{"role": "user", "content": "describe your task here..."}].
         | 
| 552 | 
            -
                         | 
| 553 | 
            -
                         | 
| 554 | 
            -
                            box in the format of:
         | 
| 555 | 
            -
                            {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
         | 
| 556 | 
            -
                            where the bounding box coordinates are normalized.
         | 
| 557 | 
            -
                        visualize_output: Whether to visualize the output.
         | 
| 558 | 
            -
                        self_reflection: boolean to enable and disable self reflection.
         | 
| 345 | 
            +
                        media (Optional[Union[str, Path]]): The media file to be used in the task.
         | 
| 346 | 
            +
                        self_reflection (bool): Whether to reflect on the task and debug the code.
         | 
| 559 347 |  | 
| 560 348 | 
             
                    Returns:
         | 
| 561 | 
            -
                         | 
| 562 | 
            -
             | 
| 563 | 
            -
                        contains the visualized output.
         | 
| 349 | 
            +
                        Dict[str, Any]: A dictionary containing the code, test, test result, plan,
         | 
| 350 | 
            +
                            and working memory of the agent.
         | 
| 564 351 | 
             
                    """
         | 
| 565 | 
            -
                    if len(chat) == 0:
         | 
| 566 | 
            -
                        raise ValueError("Input cannot be empty.")
         | 
| 567 | 
            -
             | 
| 568 | 
            -
                    question = chat[0]["content"]
         | 
| 569 | 
            -
                    if image:
         | 
| 570 | 
            -
                        question += f" Image name: {image}"
         | 
| 571 | 
            -
                    if reference_data:
         | 
| 572 | 
            -
                        question += (
         | 
| 573 | 
            -
                            f" Reference image: {reference_data['image']}"
         | 
| 574 | 
            -
                            if "image" in reference_data
         | 
| 575 | 
            -
                            else ""
         | 
| 576 | 
            -
                        )
         | 
| 577 | 
            -
                        question += (
         | 
| 578 | 
            -
                            f" Reference mask: {reference_data['mask']}"
         | 
| 579 | 
            -
                            if "mask" in reference_data
         | 
| 580 | 
            -
                            else ""
         | 
| 581 | 
            -
                        )
         | 
| 582 | 
            -
                        question += (
         | 
| 583 | 
            -
                            f" Reference bbox: {reference_data['bbox']}"
         | 
| 584 | 
            -
                            if "bbox" in reference_data
         | 
| 585 | 
            -
                            else ""
         | 
| 586 | 
            -
                        )
         | 
| 587 | 
            -
             | 
| 588 | 
            -
                    reflections = ""
         | 
| 589 | 
            -
                    final_answer = ""
         | 
| 590 | 
            -
                    all_tool_results: List[Dict] = []
         | 
| 591 352 |  | 
| 592 | 
            -
                     | 
| 593 | 
            -
                         | 
| 594 | 
            -
             | 
| 353 | 
            +
                    if len(chat) == 0:
         | 
| 354 | 
            +
                        raise ValueError("Chat cannot be empty.")
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                    if media is not None:
         | 
| 357 | 
            +
                        for chat_i in chat:
         | 
| 358 | 
            +
                            if chat_i["role"] == "user":
         | 
| 359 | 
            +
                                chat_i["content"] += f" Image name {media}"
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                    code = ""
         | 
| 362 | 
            +
                    test = ""
         | 
| 363 | 
            +
                    working_memory: List[Dict[str, str]] = []
         | 
| 364 | 
            +
                    results = {"code": "", "test": "", "plan": []}
         | 
| 365 | 
            +
                    plan = []
         | 
| 366 | 
            +
                    success = False
         | 
| 367 | 
            +
                    retries = 0
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                    while not success and retries < self.max_retries:
         | 
| 370 | 
            +
                        plan_i = write_plan(
         | 
| 371 | 
            +
                            chat,
         | 
| 372 | 
            +
                            TOOL_DESCRIPTIONS,
         | 
| 373 | 
            +
                            format_memory(working_memory),
         | 
| 374 | 
            +
                            self.planner,
         | 
| 375 | 
            +
                            media=[media] if media else None,
         | 
| 595 376 | 
             
                        )
         | 
| 596 | 
            -
             | 
| 597 | 
            -
                         | 
| 598 | 
            -
             | 
| 599 | 
            -
             | 
| 600 | 
            -
             | 
| 601 | 
            -
             | 
| 602 | 
            -
             | 
| 603 | 
            -
             | 
| 604 | 
            -
                        for task in task_list:
         | 
| 605 | 
            -
                            task_str = task["task"]
         | 
| 606 | 
            -
                            previous_log = str(task_depend)
         | 
| 607 | 
            -
                            tool_results, call_results = self.retrieval(
         | 
| 608 | 
            -
                                self.task_model,
         | 
| 609 | 
            -
                                task_str,
         | 
| 610 | 
            -
                                self.tools,
         | 
| 611 | 
            -
                                previous_log,
         | 
| 612 | 
            -
                                reflections,
         | 
| 613 | 
            -
                            )
         | 
| 614 | 
            -
                            answer = answer_generate(
         | 
| 615 | 
            -
                                self.answer_model, task_str, call_results, previous_log, reflections
         | 
| 377 | 
            +
                        plan_i_str = "\n-".join([e["instructions"] for e in plan_i])
         | 
| 378 | 
            +
                        if self.verbosity >= 1:
         | 
| 379 | 
            +
                            self.log_progress(
         | 
| 380 | 
            +
                                {
         | 
| 381 | 
            +
                                    "log": "Going to run the following plan(s) in sequence:\n",
         | 
| 382 | 
            +
                                    "plan": plan_i,
         | 
| 383 | 
            +
                                }
         | 
| 616 384 | 
             
                            )
         | 
| 617 385 |  | 
| 618 | 
            -
                             | 
| 619 | 
            -
             | 
| 386 | 
            +
                            _LOGGER.info(
         | 
| 387 | 
            +
                                f"""
         | 
| 388 | 
            +
            {tabulate(tabular_data=plan_i, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
         | 
| 389 | 
            +
                            )
         | 
| 620 390 |  | 
| 621 | 
            -
             | 
| 622 | 
            -
                             | 
| 623 | 
            -
                             | 
| 624 | 
            -
                             | 
| 625 | 
            -
                             | 
| 626 | 
            -
                        final_answer = answer_summarize(
         | 
| 627 | 
            -
                            self.answer_model, question, answers, reflections
         | 
| 391 | 
            +
                        tool_info = retrieve_tools(
         | 
| 392 | 
            +
                            plan_i,
         | 
| 393 | 
            +
                            self.tool_recommender,
         | 
| 394 | 
            +
                            self.log_progress,
         | 
| 395 | 
            +
                            self.verbosity,
         | 
| 628 396 | 
             
                        )
         | 
| 629 | 
            -
                         | 
| 630 | 
            -
             | 
| 631 | 
            -
             | 
| 632 | 
            -
                             | 
| 633 | 
            -
             | 
| 634 | 
            -
                             | 
| 635 | 
            -
             | 
| 636 | 
            -
                             | 
| 397 | 
            +
                        results = write_and_test_code(
         | 
| 398 | 
            +
                            FULL_TASK.format(user_request=chat[0]["content"], subtasks=plan_i_str),
         | 
| 399 | 
            +
                            tool_info,
         | 
| 400 | 
            +
                            UTILITIES_DOCSTRING,
         | 
| 401 | 
            +
                            format_memory(working_memory),
         | 
| 402 | 
            +
                            self.coder,
         | 
| 403 | 
            +
                            self.tester,
         | 
| 404 | 
            +
                            self.debugger,
         | 
| 405 | 
            +
                            self.log_progress,
         | 
| 406 | 
            +
                            verbosity=self.verbosity,
         | 
| 407 | 
            +
                            input_media=media,
         | 
| 408 | 
            +
                        )
         | 
| 409 | 
            +
                        success = cast(bool, results["success"])
         | 
| 410 | 
            +
                        code = cast(str, results["code"])
         | 
| 411 | 
            +
                        test = cast(str, results["test"])
         | 
| 412 | 
            +
                        working_memory.extend(results["working_memory"])  # type: ignore
         | 
| 413 | 
            +
                        plan.append({"code": code, "test": test, "plan": plan_i})
         | 
| 637 414 |  | 
| 638 415 | 
             
                        if self_reflection:
         | 
| 639 | 
            -
                            reflection =  | 
| 640 | 
            -
                                 | 
| 641 | 
            -
                                 | 
| 642 | 
            -
             | 
| 643 | 
            -
                                 | 
| 644 | 
            -
                                 | 
| 645 | 
            -
                                 | 
| 646 | 
            -
                            )
         | 
| 647 | 
            -
                            self.log_progress({"log": f"Reflection: {reflection}"})
         | 
| 648 | 
            -
                            parsed_reflection = parse_reflect(reflection)
         | 
| 649 | 
            -
                            if parsed_reflection["Finish"]:
         | 
| 650 | 
            -
                                break
         | 
| 651 | 
            -
                            else:
         | 
| 652 | 
            -
                                reflections += "\n" + parsed_reflection["Reflection"]
         | 
| 653 | 
            -
                        else:
         | 
| 654 | 
            -
                            self.log_progress(
         | 
| 655 | 
            -
                                {"log": "Self Reflection skipped based on user request."}
         | 
| 416 | 
            +
                            reflection = reflect(
         | 
| 417 | 
            +
                                chat,
         | 
| 418 | 
            +
                                FULL_TASK.format(
         | 
| 419 | 
            +
                                    user_request=chat[0]["content"], subtasks=plan_i_str
         | 
| 420 | 
            +
                                ),
         | 
| 421 | 
            +
                                code,
         | 
| 422 | 
            +
                                self.planner,
         | 
| 656 423 | 
             
                            )
         | 
| 657 | 
            -
                             | 
| 658 | 
            -
             | 
| 659 | 
            -
             | 
| 660 | 
            -
             | 
| 661 | 
            -
             | 
| 662 | 
            -
             | 
| 663 | 
            -
             | 
| 664 | 
            -
             | 
| 665 | 
            -
             | 
| 666 | 
            -
             | 
| 667 | 
            -
                            " | 
| 668 | 
            -
                        ]
         | 
| 669 | 
            -
                        self._report_visualization_via_callback(viz_images)
         | 
| 670 | 
            -
                        for img in viz_images:
         | 
| 671 | 
            -
                            Image.open(img).show()
         | 
| 672 | 
            -
             | 
| 673 | 
            -
                    return final_answer, all_tool_results
         | 
| 674 | 
            -
             | 
| 675 | 
            -
                def chat(
         | 
| 676 | 
            -
                    self,
         | 
| 677 | 
            -
                    chat: List[Dict[str, str]],
         | 
| 678 | 
            -
                    image: Optional[Union[str, Path]] = None,
         | 
| 679 | 
            -
                    reference_data: Optional[Dict[str, str]] = None,
         | 
| 680 | 
            -
                    visualize_output: Optional[bool] = False,
         | 
| 681 | 
            -
                    self_reflection: Optional[bool] = True,
         | 
| 682 | 
            -
                ) -> str:
         | 
| 683 | 
            -
                    answer, _ = self.chat_with_workflow(
         | 
| 684 | 
            -
                        chat,
         | 
| 685 | 
            -
                        image=image,
         | 
| 686 | 
            -
                        visualize_output=visualize_output,
         | 
| 687 | 
            -
                        reference_data=reference_data,
         | 
| 688 | 
            -
                        self_reflection=self_reflection,
         | 
| 689 | 
            -
                    )
         | 
| 690 | 
            -
                    return answer
         | 
| 691 | 
            -
             | 
| 692 | 
            -
                def retrieval(
         | 
| 693 | 
            -
                    self,
         | 
| 694 | 
            -
                    model: Union[LLM, LMM, Agent],
         | 
| 695 | 
            -
                    question: str,
         | 
| 696 | 
            -
                    tools: Dict[int, Any],
         | 
| 697 | 
            -
                    previous_log: str,
         | 
| 698 | 
            -
                    reflections: str,
         | 
| 699 | 
            -
                ) -> Tuple[Dict, str]:
         | 
| 700 | 
            -
                    tool_id = choose_tool(
         | 
| 701 | 
            -
                        model,
         | 
| 702 | 
            -
                        question,
         | 
| 703 | 
            -
                        {k: v["description"] for k, v in tools.items()},
         | 
| 704 | 
            -
                        reflections,
         | 
| 705 | 
            -
                    )
         | 
| 706 | 
            -
                    if tool_id is None:
         | 
| 707 | 
            -
                        return {}, ""
         | 
| 708 | 
            -
             | 
| 709 | 
            -
                    tool_instructions = tools[tool_id]
         | 
| 710 | 
            -
                    tool_usage = tool_instructions["usage"]
         | 
| 711 | 
            -
                    tool_name = tool_instructions["name"]
         | 
| 424 | 
            +
                            if self.verbosity > 0:
         | 
| 425 | 
            +
                                self.log_progress(
         | 
| 426 | 
            +
                                    {
         | 
| 427 | 
            +
                                        "log": "Reflection:",
         | 
| 428 | 
            +
                                        "reflection": reflection,
         | 
| 429 | 
            +
                                    }
         | 
| 430 | 
            +
                                )
         | 
| 431 | 
            +
                                _LOGGER.info(f"Reflection: {reflection}")
         | 
| 432 | 
            +
                            feedback = cast(str, reflection["feedback"])
         | 
| 433 | 
            +
                            success = cast(bool, reflection["success"])
         | 
| 434 | 
            +
                            working_memory.append({"code": f"{code}\n{test}", "feedback": feedback})
         | 
| 712 435 |  | 
| 713 | 
            -
             | 
| 714 | 
            -
                        model, question, tool_usage, previous_log, reflections
         | 
| 715 | 
            -
                    )
         | 
| 716 | 
            -
                    if parameters is None:
         | 
| 717 | 
            -
                        return {}, ""
         | 
| 718 | 
            -
                    tool_results = {
         | 
| 719 | 
            -
                        "task": question,
         | 
| 720 | 
            -
                        "tool_name": tool_name,
         | 
| 721 | 
            -
                        "parameters": parameters,
         | 
| 722 | 
            -
                    }
         | 
| 436 | 
            +
                        retries += 1
         | 
| 723 437 |  | 
| 724 438 | 
             
                    self.log_progress(
         | 
| 725 439 | 
             
                        {
         | 
| 726 | 
            -
                            "log": f" | 
| 727 | 
            -
             | 
| 440 | 
            +
                            "log": f"Vision Agent has concluded this chat.\nSuccess: {success}",
         | 
| 441 | 
            +
                            "finished": True,
         | 
| 728 442 | 
             
                        }
         | 
| 729 443 | 
             
                    )
         | 
| 730 444 |  | 
| 731 | 
            -
                     | 
| 732 | 
            -
                         | 
| 733 | 
            -
                         | 
| 734 | 
            -
             | 
| 735 | 
            -
             | 
| 736 | 
            -
             | 
| 737 | 
            -
             | 
| 738 | 
            -
                            for parameters in result["parameters"]:
         | 
| 739 | 
            -
                                call_results.append(
         | 
| 740 | 
            -
                                    function_call(tools[tool_id]["class"], parameters)
         | 
| 741 | 
            -
                                )
         | 
| 742 | 
            -
                        return call_results
         | 
| 743 | 
            -
             | 
| 744 | 
            -
                    call_results = parse_tool_results(tool_results)
         | 
| 745 | 
            -
                    tool_results["call_results"] = call_results
         | 
| 746 | 
            -
             | 
| 747 | 
            -
                    call_results_str = str(call_results)
         | 
| 748 | 
            -
                    return tool_results, call_results_str
         | 
| 445 | 
            +
                    return {
         | 
| 446 | 
            +
                        "code": code,
         | 
| 447 | 
            +
                        "test": test,
         | 
| 448 | 
            +
                        "test_result": results["test_result"],
         | 
| 449 | 
            +
                        "plan": plan,
         | 
| 450 | 
            +
                        "working_memory": working_memory,
         | 
| 451 | 
            +
                    }
         | 
| 749 452 |  | 
| 750 | 
            -
                def  | 
| 751 | 
            -
                    self | 
| 752 | 
            -
             | 
| 753 | 
            -
                     | 
| 754 | 
            -
                    tools: Dict[int, Any],
         | 
| 755 | 
            -
                    reflections: str,
         | 
| 756 | 
            -
                ) -> List[Dict]:
         | 
| 757 | 
            -
                    tasks = task_decompose(
         | 
| 758 | 
            -
                        task_model,
         | 
| 759 | 
            -
                        question,
         | 
| 760 | 
            -
                        {k: v["description"] for k, v in tools.items()},
         | 
| 761 | 
            -
                        reflections,
         | 
| 762 | 
            -
                    )
         | 
| 763 | 
            -
                    if tasks is not None:
         | 
| 764 | 
            -
                        task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
         | 
| 765 | 
            -
                        task_list = task_topology(task_model, question, task_list)
         | 
| 766 | 
            -
                        try:
         | 
| 767 | 
            -
                            task_list = topological_sort(task_list)
         | 
| 768 | 
            -
                        except Exception:
         | 
| 769 | 
            -
                            _LOGGER.error(f"Failed topological_sort on: {task_list}")
         | 
| 770 | 
            -
                    else:
         | 
| 771 | 
            -
                        task_list = []
         | 
| 772 | 
            -
                    self.log_progress(
         | 
| 773 | 
            -
                        {
         | 
| 774 | 
            -
                            "log": "Planned tasks:",
         | 
| 775 | 
            -
                            "plan": task_list,
         | 
| 776 | 
            -
                        }
         | 
| 777 | 
            -
                    )
         | 
| 778 | 
            -
                    return task_list
         | 
| 453 | 
            +
                def log_progress(self, data: Dict[str, Any]) -> None:
         | 
| 454 | 
            +
                    if self.report_progress_callback is not None:
         | 
| 455 | 
            +
                        self.report_progress_callback(data)
         | 
| 456 | 
            +
                    pass
         |