vision-agent 0.2.90__py3-none-any.whl → 0.2.92__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/agent/__init__.py +2 -1
- vision_agent/agent/agent.py +1 -1
- vision_agent/agent/agent_utils.py +43 -0
- vision_agent/agent/vision_agent.py +116 -824
- vision_agent/agent/vision_agent_coder.py +897 -0
- vision_agent/agent/vision_agent_coder_prompts.py +328 -0
- vision_agent/agent/vision_agent_prompts.py +89 -302
- vision_agent/lmm/__init__.py +2 -1
- vision_agent/lmm/lmm.py +3 -5
- vision_agent/lmm/types.py +5 -0
- vision_agent/tools/__init__.py +1 -0
- vision_agent/tools/meta_tools.py +402 -0
- vision_agent/tools/tool_utils.py +48 -2
- vision_agent/tools/tools.py +7 -49
- vision_agent/utils/execute.py +52 -76
- vision_agent/utils/image_utils.py +1 -1
- vision_agent/utils/type_defs.py +1 -1
- {vision_agent-0.2.90.dist-info → vision_agent-0.2.92.dist-info}/METADATA +42 -12
- vision_agent-0.2.92.dist-info/RECORD +29 -0
- vision_agent-0.2.90.dist-info/RECORD +0 -24
- {vision_agent-0.2.90.dist-info → vision_agent-0.2.92.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.90.dist-info → vision_agent-0.2.92.dist-info}/WHEEL +0 -0
vision_agent/agent/vision_agent_coder.py (new file):
@@ -0,0 +1,897 @@
import copy
import difflib
import logging
import os
import sys
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast

from PIL import Image
from rich.console import Console
from rich.style import Style
from rich.syntax import Syntax
from tabulate import tabulate

import vision_agent.tools as T
from vision_agent.agent import Agent
from vision_agent.agent.agent_utils import extract_code, extract_json
from vision_agent.agent.vision_agent_coder_prompts import (
    CODE,
    FIX_BUG,
    FULL_TASK,
    PICK_PLAN,
    PLAN,
    PREVIOUS_FAILED,
    SIMPLE_TEST,
    TEST_PLANS,
    USER_REQ,
)
from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OpenAILMM
from vision_agent.utils import CodeInterpreterFactory, Execution
from vision_agent.utils.execute import CodeInterpreter
from vision_agent.utils.image_utils import b64_to_pil
from vision_agent.utils.sim import AzureSim, Sim
from vision_agent.utils.video import play_video

logging.basicConfig(stream=sys.stdout)
WORKSPACE = Path(os.getenv("WORKSPACE", ""))
_LOGGER = logging.getLogger(__name__)
_MAX_TABULATE_COL_WIDTH = 80
_CONSOLE = Console()


class DefaultImports:
    """Container for default imports used in the code execution."""

    common_imports = [
        "from typing import *",
        "from pillow_heif import register_heif_opener",
        "register_heif_opener()",
    ]

    @staticmethod
    def to_code_string() -> str:
        return "\n".join(DefaultImports.common_imports + T.__new_tools__)

    @staticmethod
    def prepend_imports(code: str) -> str:
        """Run this method to prepend the default imports to the code.
        NOTE: be sure to run this method after the custom tools have been registered.
        """
        return DefaultImports.to_code_string() + "\n\n" + code
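
# Editor's illustration (not part of the package): with no custom tools
# registered, DefaultImports.prepend_imports('print("hi")') yields roughly:
#
#     from typing import *
#     from pillow_heif import register_heif_opener
#     register_heif_opener()
#
#     print("hi")
#
# plus one extra line per entry in T.__new_tools__ once custom tools exist.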

def get_diff(before: str, after: str) -> str:
    return "".join(
        difflib.unified_diff(
            before.splitlines(keepends=True), after.splitlines(keepends=True)
        )
    )
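
# Editor's sketch: get_diff("a\nb\n", "a\nc\n") returns a plain unified diff,
# approximately:
#
#     ---
#     +++
#     @@ -1,2 +1,2 @@
#      a
#     -b
#     +c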

def format_memory(memory: List[Dict[str, str]]) -> str:
    output_str = ""
    for i, m in enumerate(memory):
        output_str += f"### Feedback {i}:\n"
        output_str += f"Code {i}:\n```python\n{m['code']}```\n\n"
        output_str += f"Feedback {i}: {m['feedback']}\n\n"
        if "edits" in m:
            output_str += f"Edits {i}:\n{m['edits']}\n"
        output_str += "\n"

    return output_str
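
# Editor's note on the rendered shape: each memory entry becomes a block of the
# form "### Feedback i:", the code fenced as python, the feedback line, and an
# optional "Edits i:" diff when the entry carries an "edits" key.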

def format_plans(plans: Dict[str, Any]) -> str:
    plan_str = ""
    for k, v in plans.items():
        plan_str += f"{k}:\n"
        plan_str += "-" + "\n-".join([e["instructions"] for e in v])

    return plan_str

def extract_image(
    media: Optional[Sequence[Union[str, Path]]]
) -> Optional[Sequence[Union[str, Path]]]:
    if media is None:
        return None

    new_media = []
    for m in media:
        m = Path(m)
        extension = m.suffix
        if extension in [".jpg", ".jpeg", ".png", ".bmp"]:
            new_media.append(m)
        elif extension in [".mp4", ".mov"]:
            frames = T.extract_frames(m)
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                if len(frames) > 0:
                    Image.fromarray(frames[0][0]).save(tmp.name)
                    new_media.append(Path(tmp.name))
    if len(new_media) == 0:
        return None
    return new_media
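
# Editor's note: images pass through untouched, while for videos only the
# first extracted frame survives, saved to a temporary PNG. E.g. (hypothetical
# name) extract_image(["clip.mp4"]) would return a one-element list pointing
# at a temp .png file, or None if no frames could be extracted.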

def write_plans(
    chat: List[Message],
    tool_desc: str,
    working_memory: str,
    model: LMM,
) -> Dict[str, Any]:
    chat = copy.deepcopy(chat)
    if chat[-1]["role"] != "user":
        raise ValueError("Last chat message must be from the user.")

    user_request = chat[-1]["content"]
    context = USER_REQ.format(user_request=user_request)
    prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
    chat[-1]["content"] = prompt
    return extract_json(model.chat(chat))
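
# Editor's sketch of the planner's expected JSON (key names hypothetical; the
# structure is inferred from format_plans and retrieve_tools, which consume
# Dict[str, List[Dict[str, str]]] with "instructions" entries):
#
#     {
#         "plan1": [{"instructions": "load the image"}, {"instructions": "..."}],
#         "plan2": [{"instructions": "..."}]
#     }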

def pick_plan(
    chat: List[Message],
    plans: Dict[str, Any],
    tool_info: str,
    model: LMM,
    code_interpreter: CodeInterpreter,
    media: List[str],
    log_progress: Callable[[Dict[str, Any]], None],
    verbosity: int = 0,
    max_retries: int = 3,
) -> Tuple[str, str]:
    log_progress(
        {
            "type": "log",
            "log_content": "Generating code to pick the best plan",
            "status": "started",
        }
    )

    chat = copy.deepcopy(chat)
    if chat[-1]["role"] != "user":
        raise ValueError("Last chat message must be from the user.")

    plan_str = format_plans(plans)
    prompt = TEST_PLANS.format(
        docstring=tool_info, plans=plan_str, previous_attempts="", media=media
    )

    code = extract_code(model(prompt))
    log_progress(
        {
            "type": "log",
            "log_content": "Executing code to test plans",
            "code": DefaultImports.prepend_imports(code),
            "status": "running",
        }
    )
    tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
    tool_output_str = ""
    if len(tool_output.logs.stdout) > 0:
        tool_output_str = tool_output.logs.stdout[0]

    if verbosity == 2:
        _print_code("Initial code and tests:", code)
        _LOGGER.info(f"Initial code execution result:\n{tool_output.text()}")

    log_progress(
        {
            "type": "log",
            "log_content": (
                "Code execution succeeded"
                if tool_output.success
                else "Code execution failed"
            ),
            "payload": tool_output.to_json(),
            "status": "completed" if tool_output.success else "failed",
        }
    )

    # retry if the tool output is empty or code fails
    count = 0
    while (not tool_output.success or tool_output_str == "") and count < max_retries:
        prompt = TEST_PLANS.format(
            docstring=tool_info,
            plans=plan_str,
            previous_attempts=PREVIOUS_FAILED.format(
                code=code, error=tool_output.text()
            ),
            media=media,
        )
        log_progress(
            {
                "type": "log",
                "log_content": "Retrying code to test plans",
                "status": "running",
                "code": DefaultImports.prepend_imports(code),
            }
        )
        code = extract_code(model(prompt))
        tool_output = code_interpreter.exec_isolation(
            DefaultImports.prepend_imports(code)
        )
        log_progress(
            {
                "type": "log",
                "log_content": (
                    "Code execution succeeded"
                    if tool_output.success
                    else "Code execution failed"
                ),
                "code": DefaultImports.prepend_imports(code),
                "payload": tool_output.to_json(),
                "status": "completed" if tool_output.success else "failed",
            }
        )
        tool_output_str = ""
        if len(tool_output.logs.stdout) > 0:
            tool_output_str = tool_output.logs.stdout[0]

        if verbosity == 2:
            _print_code("Code and test after attempted fix:", code)
            _LOGGER.info(
                f"Code execution result after attempt {count}:\n{tool_output.text()}"
            )

        count += 1

    if verbosity >= 1:
        _print_code("Final code:", code)

    user_req = chat[-1]["content"]
    context = USER_REQ.format(user_request=user_req)
    # because the tool picker model gets the image as well, we have to be careful with
    # how much text we send it, so we truncate the tool output to 20,000 characters
    prompt = PICK_PLAN.format(
        context=context,
        plans=format_plans(plans),
        tool_output=tool_output_str[:20_000],
    )
    chat[-1]["content"] = prompt
    best_plan = extract_json(model(chat))

    if verbosity >= 1:
        _LOGGER.info(f"Best plan:\n{best_plan}")
    log_progress(
        {
            "type": "log",
            "log_content": "Picked best plan",
            "status": "completed",
            "payload": plans[best_plan["best_plan"]],
        }
    )
    return best_plan["best_plan"], tool_output_str
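
# Editor's note: the PICK_PLAN prompt must come back as JSON shaped like
# {"best_plan": "<plan key>"}; pick_plan indexes plans[best_plan["best_plan"]]
# above, so any other shape raises a KeyError.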

def write_code(
    coder: LMM,
    chat: List[Message],
    plan: str,
    tool_info: str,
    tool_output: str,
    feedback: str,
) -> str:
    chat = copy.deepcopy(chat)
    if chat[-1]["role"] != "user":
        raise ValueError("Last chat message must be from the user.")

    user_request = chat[-1]["content"]
    prompt = CODE.format(
        docstring=tool_info,
        question=FULL_TASK.format(user_request=user_request, subtasks=plan),
        tool_output=tool_output,
        feedback=feedback,
    )
    chat[-1]["content"] = prompt
    return extract_code(coder(chat))


def write_test(
    tester: LMM,
    chat: List[Message],
    tool_utils: str,
    code: str,
    feedback: str,
    media: Optional[Sequence[Union[str, Path]]] = None,
) -> str:
    chat = copy.deepcopy(chat)
    if chat[-1]["role"] != "user":
        raise ValueError("Last chat message must be from the user.")

    user_request = chat[-1]["content"]
    prompt = SIMPLE_TEST.format(
        docstring=tool_utils,
        question=user_request,
        code=code,
        feedback=feedback,
        media=media,
    )
    chat[-1]["content"] = prompt
    return extract_code(tester(chat))
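
# Editor's note: write_code and write_test share one pattern: deep-copy the
# chat, require the last message to come from the user, then replace that
# message's content with a fully formatted prompt before a single model call.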

def write_and_test_code(
    chat: List[Message],
    plan: str,
    tool_info: str,
    tool_output: str,
    tool_utils: str,
    working_memory: List[Dict[str, str]],
    coder: LMM,
    tester: LMM,
    debugger: LMM,
    code_interpreter: CodeInterpreter,
    log_progress: Callable[[Dict[str, Any]], None],
    verbosity: int = 0,
    max_retries: int = 3,
    media: Optional[Sequence[Union[str, Path]]] = None,
) -> Dict[str, Any]:
    log_progress(
        {
            "type": "log",
            "log_content": "Generating code",
            "status": "started",
        }
    )
    code = write_code(
        coder,
        chat,
        plan,
        tool_info,
        tool_output,
        format_memory(working_memory),
    )
    test = write_test(
        tester, chat, tool_utils, code, format_memory(working_memory), media
    )

    log_progress(
        {
            "type": "log",
            "log_content": "Running code",
            "status": "running",
            "code": DefaultImports.prepend_imports(code),
            "payload": {
                "test": test,
            },
        }
    )
    result = code_interpreter.exec_isolation(
        f"{DefaultImports.to_code_string()}\n{code}\n{test}"
    )
    log_progress(
        {
            "type": "log",
            "log_content": (
                "Code execution succeeded"
                if result.success
                else "Code execution failed"
            ),
            "status": "completed" if result.success else "failed",
            "code": DefaultImports.prepend_imports(code),
            "payload": {
                "test": test,
                "result": result.to_json(),
            },
        }
    )
    if verbosity == 2:
        _print_code("Initial code and tests:", code, test)
        _LOGGER.info(
            f"Initial code execution result:\n{result.text(include_logs=True)}"
        )

    count = 0
    new_working_memory: List[Dict[str, str]] = []
    while not result.success and count < max_retries:
        if verbosity == 2:
            _LOGGER.info(f"Start debugging attempt {count + 1}")
        code, test, result = debug_code(
            working_memory,
            debugger,
            code_interpreter,
            code,
            test,
            result,
            new_working_memory,
            log_progress,
            verbosity,
        )
        count += 1

    if verbosity >= 1:
        _print_code("Final code and tests:", code, test)

    return {
        "code": code,
        "test": test,
        "success": result.success,
        "test_result": result,
        "working_memory": new_working_memory,
    }

def debug_code(
    working_memory: List[Dict[str, str]],
    debugger: LMM,
    code_interpreter: CodeInterpreter,
    code: str,
    test: str,
    result: Execution,
    new_working_memory: List[Dict[str, str]],
    log_progress: Callable[[Dict[str, Any]], None],
    verbosity: int = 0,
) -> Tuple[str, str, Execution]:
    log_progress(
        {
            "type": "code",
            "status": "started",
        }
    )

    fixed_code_and_test = {"code": "", "test": "", "reflections": ""}
    success = False
    count = 0
    while not success and count < 3:
        try:
            fixed_code_and_test = extract_json(
                debugger(
                    FIX_BUG.format(
                        code=code,
                        tests=test,
                        result="\n".join(result.text().splitlines()[-50:]),
                        feedback=format_memory(working_memory + new_working_memory),
                    )
                )
            )
            success = True
        except Exception as e:
            _LOGGER.exception(f"Error while extracting JSON: {e}")

        count += 1

    old_code = code
    old_test = test

    if fixed_code_and_test["code"].strip() != "":
        code = extract_code(fixed_code_and_test["code"])
    if fixed_code_and_test["test"].strip() != "":
        test = extract_code(fixed_code_and_test["test"])

    new_working_memory.append(
        {
            "code": f"{code}\n{test}",
            "feedback": fixed_code_and_test["reflections"],
            "edits": get_diff(f"{old_code}\n{old_test}", f"{code}\n{test}"),
        }
    )
    log_progress(
        {
            "type": "code",
            "status": "running",
            "payload": {
                "code": DefaultImports.prepend_imports(code),
                "test": test,
            },
        }
    )

    result = code_interpreter.exec_isolation(
        f"{DefaultImports.to_code_string()}\n{code}\n{test}"
    )
    log_progress(
        {
            "type": "code",
            "status": "completed" if result.success else "failed",
            "payload": {
                "code": DefaultImports.prepend_imports(code),
                "test": test,
                "result": result.to_json(),
            },
        }
    )
    if verbosity == 2:
        _print_code("Code and test after attempted fix:", code, test)
        _LOGGER.info(
            f"Reflection: {fixed_code_and_test['reflections']}\nCode execution result after attempted fix: {result.text(include_logs=True)}"
        )

    return code, test, result
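
# Editor's sketch of the JSON the FIX_BUG prompt is expected to return (keys
# taken from fixed_code_and_test above; values illustrative):
#
#     {
#         "code": "<fixed code, or empty to keep the old code>",
#         "test": "<fixed test, or empty to keep the old test>",
#         "reflections": "<what was wrong and how it was fixed>"
#     }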

def _print_code(title: str, code: str, test: Optional[str] = None) -> None:
    _CONSOLE.print(title, style=Style(bgcolor="dark_orange3", bold=True))
    _CONSOLE.print("=" * 30 + " Code " + "=" * 30)
    _CONSOLE.print(
        Syntax(
            DefaultImports.prepend_imports(code),
            "python",
            theme="gruvbox-dark",
            line_numbers=True,
        )
    )
    if test:
        _CONSOLE.print("=" * 30 + " Test " + "=" * 30)
        _CONSOLE.print(Syntax(test, "python", theme="gruvbox-dark", line_numbers=True))

def retrieve_tools(
    plans: Dict[str, List[Dict[str, str]]],
    tool_recommender: Sim,
    log_progress: Callable[[Dict[str, Any]], None],
    verbosity: int = 0,
) -> Dict[str, str]:
    log_progress(
        {
            "type": "tools",
            "status": "started",
        }
    )
    tool_info = []
    tool_desc = []
    tool_lists: Dict[str, List[Dict[str, str]]] = {}
    for k, plan in plans.items():
        tool_lists[k] = []
        for task in plan:
            tools = tool_recommender.top_k(task["instructions"], k=2, thresh=0.3)
            tool_info.extend([e["doc"] for e in tools])
            tool_desc.extend([e["desc"] for e in tools])
            tool_lists[k].extend(
                {"description": e["desc"], "documentation": e["doc"]} for e in tools
            )

    if verbosity == 2:
        tool_desc_str = "\n".join(set(tool_desc))
        _LOGGER.info(f"Tools Description:\n{tool_desc_str}")

    tool_lists_unique = {}
    for k in tool_lists:
        tool_lists_unique[k] = "\n\n".join(
            set(e["documentation"] for e in tool_lists[k])
        )
    all_tools = "\n\n".join(set(tool_info))
    tool_lists_unique["all"] = all_tools
    return tool_lists_unique
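
# Editor's sketch of the mapping retrieve_tools returns: one de-duplicated
# documentation string per plan, plus an "all" entry that unions every
# retrieved tool doc:
#
#     {"plan1": "<docs for plan1's tools>", "plan2": "...", "all": "<every doc>"}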

class VisionAgentCoder(Agent):
    """Vision Agent Coder is an agentic framework that can output code based on a user
    request. It can plan tasks, retrieve relevant tools, write code, write tests and
    reflect on failed test cases to debug code. It is inspired by AgentCoder
    https://arxiv.org/abs/2312.13010 and Data Interpreter https://arxiv.org/abs/2402.18679

    Example
    -------
        >>> from vision_agent.agent import VisionAgentCoder
        >>> agent = VisionAgentCoder()
        >>> code = agent("What percentage of the area of the jar is filled with coffee beans?", media="jar.jpg")
    """

    def __init__(
        self,
        planner: Optional[LMM] = None,
        coder: Optional[LMM] = None,
        tester: Optional[LMM] = None,
        debugger: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_sandbox_runtime: Optional[str] = None,
    ) -> None:
        """Initialize the Vision Agent Coder.

        Parameters:
            planner (Optional[LMM]): The planner model to use. Defaults to OpenAILMM.
            coder (Optional[LMM]): The coder model to use. Defaults to OpenAILMM.
            tester (Optional[LMM]): The tester model to use. Defaults to OpenAILMM.
            debugger (Optional[LMM]): The debugger model to use. Defaults to OpenAILMM.
            tool_recommender (Optional[Sim]): The tool recommender model to use.
            verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
                highest verbosity level which will output all intermediate debugging
                code.
            report_progress_callback: a callback to report the progress of the agent.
                This is useful for streaming logs in a web application where multiple
                VisionAgentCoder instances are running in parallel. This callback
                ensures that the progress logs are not mixed up.
            code_sandbox_runtime: the code sandbox runtime to use. A code sandbox is
                used to run the generated code. It can be one of the following
                values: None, "local" or "e2b". If None, VisionAgentCoder will read
                the value from the environment variable CODE_SANDBOX_RUNTIME. If it's
                also None, the local Python runtime environment will be used.
        """

        self.planner = (
            OpenAILMM(temperature=0.0, json_mode=True) if planner is None else planner
        )
        self.coder = OpenAILMM(temperature=0.0) if coder is None else coder
        self.tester = OpenAILMM(temperature=0.0) if tester is None else tester
        self.debugger = (
            OpenAILMM(temperature=0.0, json_mode=True) if debugger is None else debugger
        )
        self.verbosity = verbosity
        if self.verbosity > 0:
            _LOGGER.setLevel(logging.INFO)

        self.tool_recommender = (
            Sim(T.TOOLS_DF, sim_key="desc")
            if tool_recommender is None
            else tool_recommender
        )
        self.report_progress_callback = report_progress_callback
        self.code_sandbox_runtime = code_sandbox_runtime

    def __call__(
        self,
        input: Union[str, List[Message]],
        media: Optional[Union[str, Path]] = None,
    ) -> str:
        """Chat with VisionAgentCoder and return intermediate information regarding the
        task.

        Parameters:
            input (Union[str, List[Message]]): A conversation in the format of
                [{"role": "user", "content": "describe your task here..."}] or a string
                of just the contents.
            media (Optional[Union[str, Path]]): The media file to be used in the task.

        Returns:
            str: The code output by the VisionAgentCoder.
        """

        if isinstance(input, str):
            input = [{"role": "user", "content": input}]
            if media is not None:
                input[0]["media"] = [media]
        results = self.chat_with_workflow(input)
        results.pop("working_memory")
        return results["code"]  # type: ignore

    def chat_with_workflow(
        self,
        chat: List[Message],
        test_multi_plan: bool = True,
        display_visualization: bool = False,
    ) -> Dict[str, Any]:
        """Chat with VisionAgentCoder and return intermediate information regarding the
        task.

        Parameters:
            chat (List[Message]): A conversation
                in the format of:
                [{"role": "user", "content": "describe your task here..."}]
                or if it contains media files, it should be in the format of:
                [{"role": "user", "content": "describe your task here...", "media": ["image1.jpg", "image2.jpg"]}]
            test_multi_plan (bool): If True, it will test tools for multiple plans and
                pick the best one based on the tool results. If False, it will go
                with the first plan.
            display_visualization (bool): If True, it opens a new window locally to
                show the image(s) created by visualization code (if there is any).

        Returns:
            Dict[str, Any]: A dictionary containing the code, test, test result, plan,
                and working memory of the agent.
        """

        if not chat:
            raise ValueError("Chat cannot be empty.")

        # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
        with CodeInterpreterFactory.new_instance(
            code_sandbox_runtime=self.code_sandbox_runtime
        ) as code_interpreter:
            chat = copy.deepcopy(chat)
            media_list = []
            for chat_i in chat:
                if "media" in chat_i:
                    for media in chat_i["media"]:
                        media = code_interpreter.upload_file(media)
                        chat_i["content"] += f" Media name {media}"  # type: ignore
                        media_list.append(media)

            int_chat = cast(
                List[Message],
                [
                    (
                        {
                            "role": c["role"],
                            "content": c["content"],
                            "media": c["media"],
                        }
                        if "media" in c
                        else {"role": c["role"], "content": c["content"]}
                    )
                    for c in chat
                ],
            )

            code = ""
            test = ""
            working_memory: List[Dict[str, str]] = []
            results = {"code": "", "test": "", "plan": []}
            plan = []
            success = False
            self.log_progress(
                {
                    "type": "log",
                    "log_content": "Creating plans",
                    "status": "started",
                }
            )
            plans = write_plans(
                int_chat,
                T.TOOL_DESCRIPTIONS,
                format_memory(working_memory),
                self.planner,
            )

            if self.verbosity >= 1:
                for p in plans:
                    _LOGGER.info(
                        f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
                    )

            tool_infos = retrieve_tools(
                plans,
                self.tool_recommender,
                self.log_progress,
                self.verbosity,
            )

            if test_multi_plan:
                best_plan, tool_output_str = pick_plan(
                    int_chat,
                    plans,
                    tool_infos["all"],
                    self.coder,
                    code_interpreter,
                    media_list,
                    self.log_progress,
                    verbosity=self.verbosity,
                )
            else:
                best_plan = list(plans.keys())[0]
                tool_output_str = ""

            if best_plan in plans and best_plan in tool_infos:
                plan_i = plans[best_plan]
                tool_info = tool_infos[best_plan]
            else:
                if self.verbosity >= 1:
                    _LOGGER.warning(
                        f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
                    )
                k = list(plans.keys())[0]
                plan_i = plans[k]
                tool_info = tool_infos[k]

            self.log_progress(
                {
                    "type": "log",
                    "log_content": "Creating plans",
                    "status": "completed",
                    "payload": tool_info,
                }
            )

            if self.verbosity >= 1:
                _LOGGER.info(
                    f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
                )

            results = write_and_test_code(
                chat=[{"role": c["role"], "content": c["content"]} for c in int_chat],
                plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]),
                tool_info=tool_info,
                tool_output=tool_output_str,
                tool_utils=T.UTILITIES_DOCSTRING,
                working_memory=working_memory,
                coder=self.coder,
                tester=self.tester,
                debugger=self.debugger,
                code_interpreter=code_interpreter,
                log_progress=self.log_progress,
                verbosity=self.verbosity,
                media=media_list,
            )
            success = cast(bool, results["success"])
            code = cast(str, results["code"])
            test = cast(str, results["test"])
            working_memory.extend(results["working_memory"])  # type: ignore
            plan.append({"code": code, "test": test, "plan": plan_i})

            execution_result = cast(Execution, results["test_result"])
            self.log_progress(
                {
                    "type": "final_code",
                    "status": "completed" if success else "failed",
                    "payload": {
                        "code": DefaultImports.prepend_imports(code),
                        "test": test,
                        "result": execution_result.to_json(),
                    },
                }
            )

            if display_visualization:
                for res in execution_result.results:
                    if res.png:
                        b64_to_pil(res.png).show()
                    if res.mp4:
                        play_video(res.mp4)

            return {
                "code": DefaultImports.prepend_imports(code),
                "test": test,
                "test_result": execution_result,
                "plan": plan,
                "working_memory": working_memory,
            }

    def log_progress(self, data: Dict[str, Any]) -> None:
        if self.report_progress_callback is not None:
            self.report_progress_callback(data)

class AzureVisionAgentCoder(VisionAgentCoder):
    """VisionAgentCoder that uses Azure OpenAI APIs for planning, coding, testing.

    Pre-requisites:
    1. Set the environment variable AZURE_OPENAI_API_KEY to your Azure OpenAI API key.
    2. Set the environment variable AZURE_OPENAI_ENDPOINT to your Azure OpenAI endpoint.

    Example
    -------
        >>> from vision_agent.agent import AzureVisionAgentCoder
        >>> agent = AzureVisionAgentCoder()
        >>> code = agent("What percentage of the area of the jar is filled with coffee beans?", media="jar.jpg")
    """

    def __init__(
        self,
        planner: Optional[LMM] = None,
        coder: Optional[LMM] = None,
        tester: Optional[LMM] = None,
        debugger: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    ) -> None:
        """Initialize the Vision Agent Coder.

        Parameters:
            planner (Optional[LMM]): The planner model to use. Defaults to AzureOpenAILMM.
            coder (Optional[LMM]): The coder model to use. Defaults to AzureOpenAILMM.
            tester (Optional[LMM]): The tester model to use. Defaults to AzureOpenAILMM.
            debugger (Optional[LMM]): The debugger model to use. Defaults to AzureOpenAILMM.
            tool_recommender (Optional[Sim]): The tool recommender model to use.
            verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
                highest verbosity level which will output all intermediate debugging
                code.
            report_progress_callback: a callback to report the progress of the agent.
                This is useful for streaming logs in a web application where multiple
                VisionAgentCoder instances are running in parallel. This callback
                ensures that the progress logs are not mixed up.
        """
        super().__init__(
            planner=(
                AzureOpenAILMM(temperature=0.0, json_mode=True)
                if planner is None
                else planner
            ),
            coder=AzureOpenAILMM(temperature=0.0) if coder is None else coder,
            tester=AzureOpenAILMM(temperature=0.0) if tester is None else tester,
            debugger=(
                AzureOpenAILMM(temperature=0.0, json_mode=True)
                if debugger is None
                else debugger
            ),
            tool_recommender=(
                AzureSim(T.TOOLS_DF, sim_key="desc")
                if tool_recommender is None
                else tool_recommender
            ),
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
        )