vision-agent 0.2.56__py3-none-any.whl → 0.2.57__py3-none-any.whl
This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
- vision_agent/__init__.py +1 -2
- vision_agent/agent/agent.py +3 -1
- vision_agent/agent/vision_agent.py +110 -81
- vision_agent/agent/vision_agent_prompts.py +1 -1
- vision_agent/lmm/__init__.py +1 -1
- vision_agent/lmm/lmm.py +54 -116
- vision_agent/tools/__init__.py +2 -1
- vision_agent/tools/tools.py +3 -3
- {vision_agent-0.2.56.dist-info → vision_agent-0.2.57.dist-info}/METADATA +36 -7
- vision_agent-0.2.57.dist-info/RECORD +23 -0
- vision_agent/agent/agent_coder.py +0 -216
- vision_agent/agent/agent_coder_prompts.py +0 -135
- vision_agent/agent/data_interpreter.py +0 -475
- vision_agent/agent/data_interpreter_prompts.py +0 -186
- vision_agent/agent/easytool.py +0 -346
- vision_agent/agent/easytool_prompts.py +0 -89
- vision_agent/agent/easytool_v2.py +0 -781
- vision_agent/agent/easytool_v2_prompts.py +0 -152
- vision_agent/agent/reflexion.py +0 -299
- vision_agent/agent/reflexion_prompts.py +0 -100
- vision_agent/llm/__init__.py +0 -1
- vision_agent/llm/llm.py +0 -176
- vision_agent/tools/easytool_tools.py +0 -1242
- vision_agent-0.2.56.dist-info/RECORD +0 -36
- {vision_agent-0.2.56.dist-info → vision_agent-0.2.57.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.56.dist-info → vision_agent-0.2.57.dist-info}/WHEEL +0 -0
--- a/vision_agent/agent/easytool_v2.py
+++ /dev/null
@@ -1,781 +0,0 @@
-import json
-import logging
-import sys
-import tempfile
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
-
-from PIL import Image
-from tabulate import tabulate
-
-from vision_agent.agent.agent import Agent
-from vision_agent.agent.easytool_prompts import (
-    ANSWER_GENERATE,
-    ANSWER_SUMMARIZE,
-    CHOOSE_PARAMETER,
-    CHOOSE_TOOL,
-    TASK_DECOMPOSE,
-    TASK_TOPOLOGY,
-)
-from vision_agent.agent.easytool_v2_prompts import (
-    ANSWER_GENERATE_DEPENDS,
-    ANSWER_SUMMARIZE_DEPENDS,
-    CHOOSE_PARAMETER_DEPENDS,
-    CHOOSE_TOOL_DEPENDS,
-    TASK_DECOMPOSE_DEPENDS,
-    VISION_AGENT_REFLECTION,
-)
-from vision_agent.llm import LLM, OpenAILLM
-from vision_agent.lmm import LMM, OpenAILMM
-from vision_agent.tools.easytool_tools import TOOLS
-from vision_agent.utils.image_utils import (
-    convert_to_b64,
-    overlay_bboxes,
-    overlay_heat_map,
-    overlay_masks,
-)
-
-logging.basicConfig(stream=sys.stdout)
-_LOGGER = logging.getLogger(__name__)
-_MAX_TABULATE_COL_WIDTH = 80
-
-
-def parse_json(s: str) -> Any:
-    s = (
-        s.replace(": True", ": true")
-        .replace(": False", ": false")
-        .replace(":True", ": true")
-        .replace(":False", ": false")
-        .replace("```", "")
-        .strip()
-    )
-    return json.loads(s)
-
-
-def change_name(name: str) -> str:
-    change_list = ["from", "class", "return", "false", "true", "id", "and", "", "ID"]
-    if name in change_list:
-        name = "is_" + name.lower()
-    return name
-
-
-def format_tools(tools: Dict[int, Any]) -> str:
-    # Format this way so it's clear what the ID's are
-    tool_str = ""
-    for key in tools:
-        tool_str += f"ID: {key} - {tools[key]}\n"
-    return tool_str
-
-
-def format_tool_usage(tools: Dict[int, Any], tool_result: List[Dict]) -> str:
-    usage = []
-    name_to_usage = {v["name"]: v["usage"] for v in tools.values()}
-    for tool_res in tool_result:
-        if "tool_name" in tool_res:
-            usage.append((tool_res["tool_name"], name_to_usage[tool_res["tool_name"]]))
-
-    usage_str = ""
-    for tool_name, tool_usage in usage:
-        usage_str += f"{tool_name} - {tool_usage}\n"
-    return usage_str
-
-
-def topological_sort(tasks: List[Dict]) -> List[Dict]:
-    in_degree = {task["id"]: 0 for task in tasks}
-    for task in tasks:
-        for dep in task["dep"]:
-            if dep in in_degree:
-                in_degree[task["id"]] += 1
-
-    queue = [task for task in tasks if in_degree[task["id"]] == 0]
-    sorted_order = []
-
-    while queue:
-        current = queue.pop(0)
-        sorted_order.append(current)
-
-        for task in tasks:
-            if current["id"] in task["dep"]:
-                in_degree[task["id"]] -= 1
-                if in_degree[task["id"]] == 0:
-                    queue.append(task)
-
-    if len(sorted_order) != len(tasks):
-        completed_ids = set([task["id"] for task in sorted_order])
-        remaining_tasks = [task for task in tasks if task["id"] not in completed_ids]
-        sorted_order.extend(remaining_tasks)
-    return sorted_order
-
-
-def task_decompose(
-    model: Union[LLM, LMM, Agent],
-    question: str,
-    tools: Dict[int, Any],
-    reflections: str,
-) -> Optional[Dict]:
-    if reflections:
-        prompt = TASK_DECOMPOSE_DEPENDS.format(
-            question=question, tools=format_tools(tools), reflections=reflections
-        )
-    else:
-        prompt = TASK_DECOMPOSE.format(question=question, tools=format_tools(tools))
-    tries = 0
-    str_result = ""
-    while True:
-        try:
-            str_result = model(prompt)
-            result = parse_json(str_result)
-            return result["Tasks"]  # type: ignore
-        except Exception:
-            if tries > 10:
-                _LOGGER.error(f"Failed task_decompose on: {str_result}")
-                return None
-            tries += 1
-            continue
-
-
-def task_topology(
-    model: Union[LLM, LMM, Agent], question: str, task_list: List[Dict]
-) -> List[Dict[str, Any]]:
-    prompt = TASK_TOPOLOGY.format(question=question, task_list=task_list)
-    tries = 0
-    str_result = ""
-    while True:
-        try:
-            str_result = model(prompt)
-            result = parse_json(str_result)
-            for elt in result["Tasks"]:
-                if isinstance(elt["dep"], str):
-                    elt["dep"] = [int(dep) for dep in elt["dep"].split(",")]
-                elif isinstance(elt["dep"], int):
-                    elt["dep"] = [elt["dep"]]
-                elif isinstance(elt["dep"], list):
-                    elt["dep"] = [int(dep) for dep in elt["dep"]]
-            return result["Tasks"]  # type: ignore
-        except Exception:
-            if tries > 10:
-                _LOGGER.error(f"Failed task_topology on: {str_result}")
-                return task_list
-            tries += 1
-            continue
-
-
-def choose_tool(
-    model: Union[LLM, LMM, Agent],
-    question: str,
-    tools: Dict[int, Any],
-    reflections: str,
-) -> Optional[int]:
-    if reflections:
-        prompt = CHOOSE_TOOL_DEPENDS.format(
-            question=question, tools=format_tools(tools), reflections=reflections
-        )
-    else:
-        prompt = CHOOSE_TOOL.format(question=question, tools=format_tools(tools))
-    tries = 0
-    str_result = ""
-    while True:
-        try:
-            str_result = model(prompt)
-            result = parse_json(str_result)
-            return result["ID"]  # type: ignore
-        except Exception:
-            if tries > 10:
-                _LOGGER.error(f"Failed choose_tool on: {str_result}")
-                return None
-            tries += 1
-            continue
-
-
-def choose_parameter(
-    model: Union[LLM, LMM, Agent],
-    question: str,
-    tool_usage: Dict,
-    previous_log: str,
-    reflections: str,
-) -> Optional[Any]:
-    # TODO: should format tool_usage
-    if reflections:
-        prompt = CHOOSE_PARAMETER_DEPENDS.format(
-            question=question,
-            tool_usage=tool_usage,
-            previous_log=previous_log,
-            reflections=reflections,
-        )
-    else:
-        prompt = CHOOSE_PARAMETER.format(
-            question=question, tool_usage=tool_usage, previous_log=previous_log
-        )
-    tries = 0
-    str_result = ""
-    while True:
-        try:
-            str_result = model(prompt)
-            result = parse_json(str_result)
-            return result["Parameters"]
-        except Exception:
-            if tries > 10:
-                _LOGGER.error(f"Failed choose_parameter on: {str_result}")
-                return None
-            tries += 1
-            continue
-
-
-def answer_generate(
-    model: Union[LLM, LMM, Agent],
-    question: str,
-    call_results: str,
-    previous_log: str,
-    reflections: str,
-) -> str:
-    if reflections:
-        prompt = ANSWER_GENERATE_DEPENDS.format(
-            question=question,
-            call_results=call_results,
-            previous_log=previous_log,
-            reflections=reflections,
-        )
-    else:
-        prompt = ANSWER_GENERATE.format(
-            question=question, call_results=call_results, previous_log=previous_log
-        )
-    return model(prompt)
-
-
-def answer_summarize(
-    model: Union[LLM, LMM, Agent], question: str, answers: List[Dict], reflections: str
-) -> str:
-    if reflections:
-        prompt = ANSWER_SUMMARIZE_DEPENDS.format(
-            question=question, answers=answers, reflections=reflections
-        )
-    else:
-        prompt = ANSWER_SUMMARIZE.format(question=question, answers=answers)
-    return model(prompt)
-
-
-def function_call(tool: Callable, parameters: Dict[str, Any]) -> Any:
-    try:
-        return tool()(**parameters)
-    except Exception as e:
-        _LOGGER.error(f"Failed function_call on: {e}")
-        # return error message so it can self-correct
-        return str(e)
-
-
-def self_reflect(
-    reflect_model: Union[LLM, LMM],
-    question: str,
-    tools: Dict[int, Any],
-    tool_result: List[Dict],
-    final_answer: str,
-    images: Optional[Sequence[Union[str, Path]]] = None,
-) -> str:
-    prompt = VISION_AGENT_REFLECTION.format(
-        question=question,
-        tools=format_tools({k: v["description"] for k, v in tools.items()}),
-        tool_usage=format_tool_usage(tools, tool_result),
-        tool_results=str(tool_result),
-        final_answer=final_answer,
-    )
-    if (
-        issubclass(type(reflect_model), LMM)
-        and images is not None
-        and all([Path(image).suffix in [".jpg", ".jpeg", ".png"] for image in images])
-    ):
-        return reflect_model(prompt, images=images)  # type: ignore
-    return reflect_model(prompt)
-
-
-def parse_reflect(reflect: str) -> Any:
-    reflect = reflect.strip()
-    try:
-        return parse_json(reflect)
-    except Exception:
-        _LOGGER.error(f"Failed parse json reflection: {reflect}")
-        # LMMs have a hard time following directions, so make the criteria less strict
-        finish = (
-            "finish" in reflect.lower() and len(reflect) < 100
-        ) or "finish" in reflect.lower()[-10:]
-        return {"Finish": finish, "Reflection": reflect}
-
-
-def _handle_extract_frames(
-    image_to_data: Dict[str, Dict], tool_result: Dict
-) -> Dict[str, Dict]:
-    image_to_data = image_to_data.copy()
-    # handle extract_frames_ case, useful if it extracts frames but doesn't do
-    # any following processing
-    for video_file_output in tool_result["call_results"]:
-        # When the video tool is run with wrong parameters, exit the loop
-        if not isinstance(video_file_output, tuple) or len(video_file_output) < 2:
-            break
-        for frame, _ in video_file_output:
-            image = frame
-            if image not in image_to_data:
-                image_to_data[image] = {
-                    "bboxes": [],
-                    "masks": [],
-                    "heat_map": [],
-                    "labels": [],
-                    "scores": [],
-                }
-    return image_to_data
-
-
-def _handle_viz_tools(
-    image_to_data: Dict[str, Dict], tool_result: Dict
-) -> Dict[str, Dict]:
-    image_to_data = image_to_data.copy()
-
-    # handle grounding_sam_ and grounding_dino_
-    parameters = tool_result["parameters"]
-    # parameters can either be a dictionary or list, parameters can also be malformed
-    # becaus the LLM builds them
-    if isinstance(parameters, dict):
-        if "image" not in parameters:
-            return image_to_data
-        parameters = [parameters]
-    elif isinstance(tool_result["parameters"], list):
-        if len(tool_result["parameters"]) < 1 or (
-            "image" not in tool_result["parameters"][0]
-        ):
-            return image_to_data
-
-    for param, call_result in zip(parameters, tool_result["call_results"]):
-        # Calls can fail, so we need to check if the call was successful. It can either:
-        # 1. return a str or some error that's not a dictionary
-        # 2. return a dictionary but not have the necessary keys
-
-        if not isinstance(call_result, dict) or (
-            "bboxes" not in call_result
-            and "mask" not in call_result
-            and "heat_map" not in call_result
-        ):
-            return image_to_data
-
-        # if the call was successful, then we can add the image data
-        image = param["image"]
-        if image not in image_to_data:
-            image_to_data[image] = {
-                "bboxes": [],
-                "masks": [],
-                "heat_map": [],
-                "labels": [],
-                "scores": [],
-            }
-
-        image_to_data[image]["bboxes"].extend(call_result.get("bboxes", []))
-        image_to_data[image]["labels"].extend(call_result.get("labels", []))
-        image_to_data[image]["scores"].extend(call_result.get("scores", []))
-        image_to_data[image]["masks"].extend(call_result.get("masks", []))
-        # only single heatmap is returned
-        if "heat_map" in call_result:
-            image_to_data[image]["heat_map"].append(call_result["heat_map"])
-        if "mask_shape" in call_result:
-            image_to_data[image]["mask_shape"] = call_result["mask_shape"]
-
-    return image_to_data
-
-
-def sample_n_evenly_spaced(lst: Sequence, n: int) -> Sequence:
-    if n <= 0:
-        return []
-    elif len(lst) == 0:
-        return []
-    elif n == 1:
-        return [lst[0]]
-    elif n >= len(lst):
-        return lst
-
-    spacing = (len(lst) - 1) / (n - 1)
-    return [lst[round(spacing * i)] for i in range(n)]
-
-
-def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]:
-    image_to_data: Dict[str, Dict] = {}
-    for tool_result in all_tool_results:
-        # only handle bbox/mask tools or frame extraction
-        if tool_result["tool_name"] not in [
-            "grounding_sam_",
-            "grounding_dino_",
-            "extract_frames_",
-            "dinov_",
-            "zero_shot_counting_",
-            "visual_prompt_counting_",
-            "ocr_",
-        ]:
-            continue
-
-        if tool_result["tool_name"] == "extract_frames_":
-            image_to_data = _handle_extract_frames(image_to_data, tool_result)
-        else:
-            image_to_data = _handle_viz_tools(image_to_data, tool_result)
-
-    visualized_images = []
-    for image_str in image_to_data:
-        image_path = Path(image_str)
-        image_data = image_to_data[image_str]
-        if "_counting_" in tool_result["tool_name"]:
-            image = overlay_heat_map(image_path, image_data)
-        else:
-            image = overlay_masks(image_path, image_data)
-            image = overlay_bboxes(image, image_data)
-        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
-            image.save(f.name)
-            visualized_images.append(f.name)
-    return visualized_images
-
-
-class EasyToolV2(Agent):
-    """EasyToolV2 is an agent framework that utilizes tools as well as self reflection
-    to accomplish tasks, in particular vision tasks. EasyToolV2 is based off of EasyTool
-    https://arxiv.org/abs/2401.06201 and Reflexion https://arxiv.org/abs/2303.11366
-    where it will attempt to complete a task and then reflect on whether or not it was
-    able to accomplish the task based off of the plan and final results, if not it will
-    redo the task with this newly added reflection.
-
-    Example
-    -------
-        >>> from vision_agent.agent import EasyToolV2
-        >>> agent = EasyToolV2()
-        >>> resp = agent("If red tomatoes cost $5 each and yellow tomatoes cost $2.50 each, what is the total cost of all the tomatoes in the image?", image="tomatoes.jpg")
-        >>> print(resp)
-        "The total cost is $57.50."
-    """
-
-    def __init__(
-        self,
-        task_model: Optional[Union[LLM, LMM]] = None,
-        answer_model: Optional[Union[LLM, LMM]] = None,
-        reflect_model: Optional[Union[LLM, LMM]] = None,
-        max_retries: int = 2,
-        verbose: bool = False,
-        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
-    ):
-        """EasyToolV2 constructor.
-
-        Parameters:
-            task_model: the model to use for task decomposition.
-            answer_model: the model to use for reasoning and concluding the answer.
-            reflect_model: the model to use for self reflection.
-            max_retries: maximum number of retries to attempt to complete the task.
-            verbose: whether to print more logs.
-            report_progress_callback: a callback to report the progress of the agent.
-                This is useful for streaming logs in a web application where multiple
-                EasyToolV2 instances are running in parallel. This callback ensures
-                that the progress are not mixed up.
-        """
-        self.task_model = (
-            OpenAILLM(model_name="gpt-4-turbo", json_mode=True, temperature=0.0)
-            if task_model is None
-            else task_model
-        )
-        self.answer_model = (
-            OpenAILLM(model_name="gpt-4-turbo", temperature=0.0)
-            if answer_model is None
-            else answer_model
-        )
-        self.reflect_model = (
-            OpenAILMM(model_name="gpt-4-turbo", json_mode=True, temperature=0.0)
-            if reflect_model is None
-            else reflect_model
-        )
-        self.max_retries = max_retries
-        self.tools = TOOLS
-        self.report_progress_callback = report_progress_callback
-        if verbose:
-            _LOGGER.setLevel(logging.INFO)
-
-    def __call__(
-        self,
-        input: Union[List[Dict[str, str]], str],
-        media: Optional[Union[str, Path]] = None,
-        reference_data: Optional[Dict[str, str]] = None,
-        visualize_output: Optional[bool] = False,
-        self_reflection: Optional[bool] = True,
-    ) -> str:
-        """Invoke the vision agent.
-
-        Parameters:
-            input: A conversation in the format of
-                [{"role": "user", "content": "describe your task here..."}] or a string
-                containing just the content.
-            media: The input media referenced in the chat parameter.
-            reference_data: A dictionary containing the reference image, mask or bounding
-                box in the format of:
-                {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
-                where the bounding box coordinates are normalized.
-            visualize_output: Whether to visualize the output.
-            self_reflection: boolean to enable and disable self reflection.
-
-        Returns:
-            The result of the vision agent in text.
-        """
-        if isinstance(input, str):
-            input = [{"role": "user", "content": input}]
-        return self.chat(
-            input,
-            media=media,
-            visualize_output=visualize_output,
-            reference_data=reference_data,
-            self_reflection=self_reflection,
-        )
-
-    def log_progress(self, data: Dict[str, Any]) -> None:
-        _LOGGER.info(data)
-        if self.report_progress_callback:
-            self.report_progress_callback(data)
-
-    def _report_visualization_via_callback(
-        self, images: Sequence[Union[str, Path]]
-    ) -> None:
-        """This is intended for streaming the visualization images via the callback to the client side."""
-        if self.report_progress_callback:
-            self.report_progress_callback({"log": "<VIZ>"})
-            if images:
-                for img in images:
-                    self.report_progress_callback(
-                        {"log": f"<IMG>base:64{convert_to_b64(img)}</IMG>"}
-                    )
-            self.report_progress_callback({"log": "</VIZ>"})
-
-    def chat_with_workflow(
-        self,
-        chat: List[Dict[str, str]],
-        media: Optional[Union[str, Path]] = None,
-        reference_data: Optional[Dict[str, str]] = None,
-        visualize_output: Optional[bool] = False,
-        self_reflection: Optional[bool] = True,
-    ) -> Tuple[str, List[Dict]]:
-        """Chat with EasyToolV2 and return the final answer and all tool results.
-
-        Parameters:
-            chat: A conversation in the format of
-                [{"role": "user", "content": "describe your task here..."}].
-            media: The media image referenced in the chat parameter.
-            reference_data: A dictionary containing the reference image, mask or bounding
-                box in the format of:
-                {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
-                where the bounding box coordinates are normalized.
-            visualize_output: Whether to visualize the output.
-            self_reflection: boolean to enable and disable self reflection.
-
-        Returns:
-            Tuple[str, List[Dict]]: A tuple where the first item is the final answer
-                and the second item is a list of all the tool results.
-        """
-        if len(chat) == 0:
-            raise ValueError("Input cannot be empty.")
-
-        question = chat[0]["content"]
-        if media:
-            question += f" Image name: {media}"
-        if reference_data:
-            question += (
-                f" Reference image: {reference_data['image']}"
-                if "image" in reference_data
-                else ""
-            )
-            question += (
-                f" Reference mask: {reference_data['mask']}"
-                if "mask" in reference_data
-                else ""
-            )
-            question += (
-                f" Reference bbox: {reference_data['bbox']}"
-                if "bbox" in reference_data
-                else ""
-            )
-
-        reflections = ""
-        final_answer = ""
-        all_tool_results: List[Dict] = []
-
-        for _ in range(self.max_retries):
-            task_list = self.create_tasks(
-                self.task_model, question, self.tools, reflections
-            )
-
-            task_depend = {"Original Question": question}
-            previous_log = ""
-            answers = []
-            for task in task_list:
-                task_depend[task["id"]] = {"task": task["task"], "answer": "", "call_result": ""}  # type: ignore
-            all_tool_results = []
-
-            for task in task_list:
-                task_str = task["task"]
-                previous_log = str(task_depend)
-                tool_results, call_results = self.retrieval(
-                    self.task_model,
-                    task_str,
-                    self.tools,
-                    previous_log,
-                    reflections,
-                )
-                answer = answer_generate(
-                    self.answer_model, task_str, call_results, previous_log, reflections
-                )
-
-                tool_results["answer"] = answer
-                all_tool_results.append(tool_results)
-
-                self.log_progress({"log": f"\tCall Result: {call_results}"})
-                self.log_progress({"log": f"\tAnswer: {answer}"})
-                answers.append({"task": task_str, "answer": answer})
-                task_depend[task["id"]]["answer"] = answer  # type: ignore
-                task_depend[task["id"]]["call_result"] = call_results  # type: ignore
-            final_answer = answer_summarize(
-                self.answer_model, question, answers, reflections
-            )
-            visualized_output = visualize_result(all_tool_results)
-            all_tool_results.append({"visualized_output": visualized_output})
-            if len(visualized_output) > 0:
-                reflection_images = sample_n_evenly_spaced(visualized_output, 3)
-            elif media is not None:
-                reflection_images = [media]
-            else:
-                reflection_images = None
-
-            if self_reflection:
-                reflection = self_reflect(
-                    self.reflect_model,
-                    question,
-                    self.tools,
-                    all_tool_results,
-                    final_answer,
-                    reflection_images,
-                )
-                self.log_progress({"log": f"Reflection: {reflection}"})
-                parsed_reflection = parse_reflect(reflection)
-                if parsed_reflection["Finish"]:
-                    break
-                else:
-                    reflections += "\n" + parsed_reflection["Reflection"]
-            else:
-                self.log_progress(
-                    {"log": "Self Reflection skipped based on user request."}
-                )
-                break
-        # '<ANSWER>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
-        self.log_progress(
-            {
-                "log": f"EasyToolV2 has concluded this chat. <ANSWER>{final_answer}</ANSWER>"
-            }
-        )
-
-        if visualize_output:
-            viz_images: Sequence[Union[str, Path]] = all_tool_results[-1][
-                "visualized_output"
-            ]
-            self._report_visualization_via_callback(viz_images)
-            for img in viz_images:
-                Image.open(img).show()
-
-        return final_answer, all_tool_results
-
-    def chat(
-        self,
-        chat: List[Dict[str, str]],
-        media: Optional[Union[str, Path]] = None,
-        reference_data: Optional[Dict[str, str]] = None,
-        visualize_output: Optional[bool] = False,
-        self_reflection: Optional[bool] = True,
-    ) -> str:
-        answer, _ = self.chat_with_workflow(
-            chat,
-            media=media,
-            visualize_output=visualize_output,
-            reference_data=reference_data,
-            self_reflection=self_reflection,
-        )
-        return answer
-
-    def retrieval(
-        self,
-        model: Union[LLM, LMM, Agent],
-        question: str,
-        tools: Dict[int, Any],
-        previous_log: str,
-        reflections: str,
-    ) -> Tuple[Dict, str]:
-        tool_id = choose_tool(
-            model,
-            question,
-            {k: v["description"] for k, v in tools.items()},
-            reflections,
-        )
-        if tool_id is None:
-            return {}, ""
-
-        tool_instructions = tools[tool_id]
-        tool_usage = tool_instructions["usage"]
-        tool_name = tool_instructions["name"]
-
-        parameters = choose_parameter(
-            model, question, tool_usage, previous_log, reflections
-        )
-        if parameters is None:
-            return {}, ""
-        tool_results = {
-            "task": question,
-            "tool_name": tool_name,
-            "parameters": parameters,
-        }
-
-        self.log_progress(
-            {
-                "log": f"""Going to run the following tool(s) in sequence:
-{tabulate(tabular_data=[tool_results], headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
-            }
-        )
-
-        def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
-            call_results: List[Any] = []
-            if isinstance(result["parameters"], Dict):
-                call_results.append(
-                    function_call(tools[tool_id]["class"], result["parameters"])
-                )
-            elif isinstance(result["parameters"], List):
-                for parameters in result["parameters"]:
-                    call_results.append(
-                        function_call(tools[tool_id]["class"], parameters)
-                    )
-            return call_results
-
-        call_results = parse_tool_results(tool_results)
-        tool_results["call_results"] = call_results
-
-        call_results_str = str(call_results)
-        return tool_results, call_results_str
-
-    def create_tasks(
-        self,
-        task_model: Union[LLM, LMM],
-        question: str,
-        tools: Dict[int, Any],
-        reflections: str,
-    ) -> List[Dict]:
-        tasks = task_decompose(
-            task_model,
-            question,
-            {k: v["description"] for k, v in tools.items()},
-            reflections,
-        )
-        if tasks is not None:
-            task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
-            task_list = task_topology(task_model, question, task_list)
-            try:
-                task_list = topological_sort(task_list)
-            except Exception:
-                _LOGGER.error(f"Failed topological_sort on: {task_list}")
-        else:
-            task_list = []
-        self.log_progress(
-            {
-                "log": "Planned tasks:",
-                "plan": task_list,
-            }
-        )
-        return task_list