vision-agent 0.2.90__py3-none-any.whl → 0.2.92__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,897 @@
+ import copy
+ import difflib
+ import logging
+ import os
+ import sys
+ import tempfile
+ from pathlib import Path
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
+
+ from PIL import Image
+ from rich.console import Console
+ from rich.style import Style
+ from rich.syntax import Syntax
+ from tabulate import tabulate
+
+ import vision_agent.tools as T
+ from vision_agent.agent import Agent
+ from vision_agent.agent.agent_utils import extract_code, extract_json
+ from vision_agent.agent.vision_agent_coder_prompts import (
+     CODE,
+     FIX_BUG,
+     FULL_TASK,
+     PICK_PLAN,
+     PLAN,
+     PREVIOUS_FAILED,
+     SIMPLE_TEST,
+     TEST_PLANS,
+     USER_REQ,
+ )
+ from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OpenAILMM
+ from vision_agent.utils import CodeInterpreterFactory, Execution
+ from vision_agent.utils.execute import CodeInterpreter
+ from vision_agent.utils.image_utils import b64_to_pil
+ from vision_agent.utils.sim import AzureSim, Sim
+ from vision_agent.utils.video import play_video
+
+ logging.basicConfig(stream=sys.stdout)
+ WORKSPACE = Path(os.getenv("WORKSPACE", ""))
+ _LOGGER = logging.getLogger(__name__)
+ _MAX_TABULATE_COL_WIDTH = 80
+ _CONSOLE = Console()
+
+
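+ # The generated code runs in an isolated interpreter that does not share this
+ # module's imports, so the default imports and any newly registered tools are
+ # prepended to every snippet before execution.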
+ class DefaultImports:
+     """Container for default imports used in the code execution."""
+
+     common_imports = [
+         "from typing import *",
+         "from pillow_heif import register_heif_opener",
+         "register_heif_opener()",
+     ]
+
+     @staticmethod
+     def to_code_string() -> str:
+         return "\n".join(DefaultImports.common_imports + T.__new_tools__)
+
+     @staticmethod
+     def prepend_imports(code: str) -> str:
+         """Run this method to prepend the default imports to the code.
+         NOTE: be sure to run this method after the custom tools have been registered.
+         """
+         return DefaultImports.to_code_string() + "\n\n" + code
+
+
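+ # Returns a unified diff between two code strings; used below to record the edits
+ # the debugger made to a failing program in working memory. For example,
+ # get_diff("x = 1\n", "x = 2\n") yields a "-x = 1" / "+x = 2" hunk.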
+ def get_diff(before: str, after: str) -> str:
+     return "".join(
+         difflib.unified_diff(
+             before.splitlines(keepends=True), after.splitlines(keepends=True)
+         )
+     )
+
+
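+ # Renders the working memory (prior code attempts, feedback and diffs) as markdown
+ # so it can be injected into the planning, coding and bug-fixing prompts.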
+ def format_memory(memory: List[Dict[str, str]]) -> str:
+     output_str = ""
+     for i, m in enumerate(memory):
+         output_str += f"### Feedback {i}:\n"
+         output_str += f"Code {i}:\n```python\n{m['code']}```\n\n"
+         output_str += f"Feedback {i}: {m['feedback']}\n\n"
+         if "edits" in m:
+             output_str += f"Edits {i}:\n{m['edits']}\n"
+         output_str += "\n"
+
+     return output_str
+
+
+ def format_plans(plans: Dict[str, Any]) -> str:
+     plan_str = ""
+     for k, v in plans.items():
+         plan_str += f"{k}:\n"
+         plan_str += "-" + "\n-".join([e["instructions"] for e in v]) + "\n"
+
+     return plan_str
+
+
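+ # Normalizes the media list to still images: image files pass through unchanged,
+ # while for videos the first extracted frame is saved to a temporary .png and
+ # used in place of the video.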
+ def extract_image(
+     media: Optional[Sequence[Union[str, Path]]]
+ ) -> Optional[Sequence[Union[str, Path]]]:
+     if media is None:
+         return None
+
+     new_media = []
+     for m in media:
+         m = Path(m)
+         extension = m.suffix
+         if extension in [".jpg", ".jpeg", ".png", ".bmp"]:
+             new_media.append(m)
+         elif extension in [".mp4", ".mov"]:
+             frames = T.extract_frames(m)
+             with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+                 if len(frames) > 0:
+                     Image.fromarray(frames[0][0]).save(tmp.name)
+                     new_media.append(Path(tmp.name))
+     if len(new_media) == 0:
+         return None
+     return new_media
+
+
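+ # Asks the planner model for candidate plans: the last user message is rewritten
+ # into the PLAN prompt (user request + tool descriptions + accumulated feedback)
+ # and the response is parsed as JSON mapping plan names to instruction lists.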
+ def write_plans(
+     chat: List[Message],
+     tool_desc: str,
+     working_memory: str,
+     model: LMM,
+ ) -> Dict[str, Any]:
+     chat = copy.deepcopy(chat)
+     if chat[-1]["role"] != "user":
+         raise ValueError("Last chat message must be from the user.")
+
+     user_request = chat[-1]["content"]
+     context = USER_REQ.format(user_request=user_request)
+     prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
+     chat[-1]["content"] = prompt
+     return extract_json(model.chat(chat))
+
+
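+ # Picks the best candidate plan by generating and executing code that tries the
+ # tools from each plan, then showing the (truncated) output of that code to the
+ # model so it can choose. Retries up to max_retries times if the test code fails
+ # or prints nothing.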
+ def pick_plan(
+     chat: List[Message],
+     plans: Dict[str, Any],
+     tool_info: str,
+     model: LMM,
+     code_interpreter: CodeInterpreter,
+     media: List[str],
+     log_progress: Callable[[Dict[str, Any]], None],
+     verbosity: int = 0,
+     max_retries: int = 3,
+ ) -> Tuple[str, str]:
+     log_progress(
+         {
+             "type": "log",
+             "log_content": "Generating code to pick the best plan",
+             "status": "started",
+         }
+     )
+
+     chat = copy.deepcopy(chat)
+     if chat[-1]["role"] != "user":
+         raise ValueError("Last chat message must be from the user.")
+
+     plan_str = format_plans(plans)
+     prompt = TEST_PLANS.format(
+         docstring=tool_info, plans=plan_str, previous_attempts="", media=media
+     )
+
+     code = extract_code(model(prompt))
+     log_progress(
+         {
+             "type": "log",
+             "log_content": "Executing code to test plans",
+             "code": DefaultImports.prepend_imports(code),
+             "status": "running",
+         }
+     )
+     tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
+     tool_output_str = ""
+     if len(tool_output.logs.stdout) > 0:
+         tool_output_str = tool_output.logs.stdout[0]
+
+     if verbosity == 2:
+         _print_code("Initial code and tests:", code)
+         _LOGGER.info(f"Initial code execution result:\n{tool_output.text()}")
+
+     log_progress(
+         {
+             "type": "log",
+             "log_content": (
+                 "Code execution succeeded"
+                 if tool_output.success
+                 else "Code execution failed"
+             ),
+             "payload": tool_output.to_json(),
+             "status": "completed" if tool_output.success else "failed",
+         }
+     )
+
+     # retry if the tool output is empty or code fails
+     count = 0
+     while (not tool_output.success or tool_output_str == "") and count < max_retries:
+         prompt = TEST_PLANS.format(
+             docstring=tool_info,
+             plans=plan_str,
+             previous_attempts=PREVIOUS_FAILED.format(
+                 code=code, error=tool_output.text()
+             ),
+             media=media,
+         )
+         code = extract_code(model(prompt))
+         log_progress(
+             {
+                 "type": "log",
+                 "log_content": "Retrying code to test plans",
+                 "status": "running",
+                 "code": DefaultImports.prepend_imports(code),
+             }
+         )
+         tool_output = code_interpreter.exec_isolation(
+             DefaultImports.prepend_imports(code)
+         )
+         log_progress(
+             {
+                 "type": "log",
+                 "log_content": (
+                     "Code execution succeeded"
+                     if tool_output.success
+                     else "Code execution failed"
+                 ),
+                 "code": DefaultImports.prepend_imports(code),
+                 "payload": tool_output.to_json(),
+                 "status": "completed" if tool_output.success else "failed",
+             }
+         )
+         tool_output_str = ""
+         if len(tool_output.logs.stdout) > 0:
+             tool_output_str = tool_output.logs.stdout[0]
+
+         if verbosity == 2:
+             _print_code("Code and test after attempted fix:", code)
+             _LOGGER.info(f"Code execution result after attempt {count + 1}:\n{tool_output.text()}")
+
+         count += 1
+
+     if verbosity >= 1:
+         _print_code("Final code:", code)
+
+     user_req = chat[-1]["content"]
+     context = USER_REQ.format(user_request=user_req)
+     # because the tool picker model gets the image as well, we have to be careful with
+     # how much text we send it, so we truncate the tool output to 20,000 characters
+     prompt = PICK_PLAN.format(
+         context=context,
+         plans=format_plans(plans),
+         tool_output=tool_output_str[:20_000],
+     )
+     chat[-1]["content"] = prompt
+     best_plan = extract_json(model(chat))
+
+     if verbosity >= 1:
+         _LOGGER.info(f"Best plan:\n{best_plan}")
+     log_progress(
+         {
+             "type": "log",
+             "log_content": "Picked best plan",
+             "status": "completed",
+             "payload": plans[best_plan["best_plan"]],
+         }
+     )
+     return best_plan["best_plan"], tool_output_str
+
+
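+ # write_code and write_test below share the same pattern: deep-copy the chat,
+ # rewrite the final user message into a task-specific prompt (CODE or
+ # SIMPLE_TEST), and extract the fenced code block from the model's reply.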
+ def write_code(
+     coder: LMM,
+     chat: List[Message],
+     plan: str,
+     tool_info: str,
+     tool_output: str,
+     feedback: str,
+ ) -> str:
+     chat = copy.deepcopy(chat)
+     if chat[-1]["role"] != "user":
+         raise ValueError("Last chat message must be from the user.")
+
+     user_request = chat[-1]["content"]
+     prompt = CODE.format(
+         docstring=tool_info,
+         question=FULL_TASK.format(user_request=user_request, subtasks=plan),
+         tool_output=tool_output,
+         feedback=feedback,
+     )
+     chat[-1]["content"] = prompt
+     return extract_code(coder(chat))
+
+
+ def write_test(
+     tester: LMM,
+     chat: List[Message],
+     tool_utils: str,
+     code: str,
+     feedback: str,
+     media: Optional[Sequence[Union[str, Path]]] = None,
+ ) -> str:
+     chat = copy.deepcopy(chat)
+     if chat[-1]["role"] != "user":
+         raise ValueError("Last chat message must be from the user.")
+
+     user_request = chat[-1]["content"]
+     prompt = SIMPLE_TEST.format(
+         docstring=tool_utils,
+         question=user_request,
+         code=code,
+         feedback=feedback,
+         media=media,
+     )
+     chat[-1]["content"] = prompt
+     return extract_code(tester(chat))
+
+
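+ # Core generate-and-verify loop: write the code and its test, run both in an
+ # isolated interpreter, and on failure hand off to debug_code for up to
+ # max_retries fix attempts, accumulating reflections in new_working_memory.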
+ def write_and_test_code(
+     chat: List[Message],
+     plan: str,
+     tool_info: str,
+     tool_output: str,
+     tool_utils: str,
+     working_memory: List[Dict[str, str]],
+     coder: LMM,
+     tester: LMM,
+     debugger: LMM,
+     code_interpreter: CodeInterpreter,
+     log_progress: Callable[[Dict[str, Any]], None],
+     verbosity: int = 0,
+     max_retries: int = 3,
+     media: Optional[Sequence[Union[str, Path]]] = None,
+ ) -> Dict[str, Any]:
+     log_progress(
+         {
+             "type": "log",
+             "log_content": "Generating code",
+             "status": "started",
+         }
+     )
+     code = write_code(
+         coder,
+         chat,
+         plan,
+         tool_info,
+         tool_output,
+         format_memory(working_memory),
+     )
+     test = write_test(
+         tester, chat, tool_utils, code, format_memory(working_memory), media
+     )
+
+     log_progress(
+         {
+             "type": "log",
+             "log_content": "Running code",
+             "status": "running",
+             "code": DefaultImports.prepend_imports(code),
+             "payload": {
+                 "test": test,
+             },
+         }
+     )
+     result = code_interpreter.exec_isolation(
+         f"{DefaultImports.to_code_string()}\n{code}\n{test}"
+     )
+     log_progress(
+         {
+             "type": "log",
+             "log_content": (
+                 "Code execution succeeded"
+                 if result.success
+                 else "Code execution failed"
+             ),
+             "status": "completed" if result.success else "failed",
+             "code": DefaultImports.prepend_imports(code),
+             "payload": {
+                 "test": test,
+                 "result": result.to_json(),
+             },
+         }
+     )
+     if verbosity == 2:
+         _print_code("Initial code and tests:", code, test)
+         _LOGGER.info(
+             f"Initial code execution result:\n{result.text(include_logs=True)}"
+         )
+
+     count = 0
+     new_working_memory: List[Dict[str, str]] = []
+     while not result.success and count < max_retries:
+         if verbosity == 2:
+             _LOGGER.info(f"Start debugging attempt {count + 1}")
+         code, test, result = debug_code(
+             working_memory,
+             debugger,
+             code_interpreter,
+             code,
+             test,
+             result,
+             new_working_memory,
+             log_progress,
+             verbosity,
+         )
+         count += 1
+
+     if verbosity >= 1:
+         _print_code("Final code and tests:", code, test)
+
+     return {
+         "code": code,
+         "test": test,
+         "success": result.success,
+         "test_result": result,
+         "working_memory": new_working_memory,
+     }
+
+
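+ # One debugging round: ask the debugger model for fixed code/tests plus a
+ # reflection (retrying the JSON parse up to 3 times), record the resulting diff
+ # in working memory, then re-run the fixed program to get a fresh Execution.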
+ def debug_code(
+     working_memory: List[Dict[str, str]],
+     debugger: LMM,
+     code_interpreter: CodeInterpreter,
+     code: str,
+     test: str,
+     result: Execution,
+     new_working_memory: List[Dict[str, str]],
+     log_progress: Callable[[Dict[str, Any]], None],
+     verbosity: int = 0,
+ ) -> Tuple[str, str, Execution]:
+     log_progress(
+         {
+             "type": "code",
+             "status": "started",
+         }
+     )
+
+     fixed_code_and_test = {"code": "", "test": "", "reflections": ""}
+     success = False
+     count = 0
+     while not success and count < 3:
+         try:
+             fixed_code_and_test = extract_json(
+                 debugger(
+                     FIX_BUG.format(
+                         code=code,
+                         tests=test,
+                         result="\n".join(result.text().splitlines()[-50:]),
+                         feedback=format_memory(working_memory + new_working_memory),
+                     )
+                 )
+             )
+             success = True
+         except Exception as e:
+             _LOGGER.exception(f"Error while extracting JSON: {e}")
+
+         count += 1
+
+     old_code = code
+     old_test = test
+
+     if fixed_code_and_test["code"].strip() != "":
+         code = extract_code(fixed_code_and_test["code"])
+     if fixed_code_and_test["test"].strip() != "":
+         test = extract_code(fixed_code_and_test["test"])
+
+     new_working_memory.append(
+         {
+             "code": f"{code}\n{test}",
+             "feedback": fixed_code_and_test["reflections"],
+             "edits": get_diff(f"{old_code}\n{old_test}", f"{code}\n{test}"),
+         }
+     )
+     log_progress(
+         {
+             "type": "code",
+             "status": "running",
+             "payload": {
+                 "code": DefaultImports.prepend_imports(code),
+                 "test": test,
+             },
+         }
+     )
+
+     result = code_interpreter.exec_isolation(
+         f"{DefaultImports.to_code_string()}\n{code}\n{test}"
+     )
+     log_progress(
+         {
+             "type": "code",
+             "status": "completed" if result.success else "failed",
+             "payload": {
+                 "code": DefaultImports.prepend_imports(code),
+                 "test": test,
+                 "result": result.to_json(),
+             },
+         }
+     )
+     if verbosity == 2:
+         _print_code("Code and test after attempted fix:", code, test)
+         _LOGGER.info(
+             f"Reflection: {fixed_code_and_test['reflections']}\nCode execution result after attempted fix: {result.text(include_logs=True)}"
+         )
+
+     return code, test, result
+
+
+ def _print_code(title: str, code: str, test: Optional[str] = None) -> None:
+     _CONSOLE.print(title, style=Style(bgcolor="dark_orange3", bold=True))
+     _CONSOLE.print("=" * 30 + " Code " + "=" * 30)
+     _CONSOLE.print(
+         Syntax(
+             DefaultImports.prepend_imports(code),
+             "python",
+             theme="gruvbox-dark",
+             line_numbers=True,
+         )
+     )
+     if test:
+         _CONSOLE.print("=" * 30 + " Test " + "=" * 30)
+         _CONSOLE.print(Syntax(test, "python", theme="gruvbox-dark", line_numbers=True))
+
+
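+ # Looks up the top 2 tools (similarity threshold 0.3) for every instruction in
+ # every plan, then returns de-duplicated tool documentation keyed by plan name,
+ # plus an "all" entry that merges the documentation across plans.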
+ def retrieve_tools(
+     plans: Dict[str, List[Dict[str, str]]],
+     tool_recommender: Sim,
+     log_progress: Callable[[Dict[str, Any]], None],
+     verbosity: int = 0,
+ ) -> Dict[str, str]:
+     log_progress(
+         {
+             "type": "tools",
+             "status": "started",
+         }
+     )
+     tool_info = []
+     tool_desc = []
+     tool_lists: Dict[str, List[Dict[str, str]]] = {}
+     for k, plan in plans.items():
+         tool_lists[k] = []
+         for task in plan:
+             tools = tool_recommender.top_k(task["instructions"], k=2, thresh=0.3)
+             tool_info.extend([e["doc"] for e in tools])
+             tool_desc.extend([e["desc"] for e in tools])
+             tool_lists[k].extend(
+                 {"description": e["desc"], "documentation": e["doc"]} for e in tools
+             )
+
+     if verbosity == 2:
+         tool_desc_str = "\n".join(set(tool_desc))
+         _LOGGER.info(f"Tools Description:\n{tool_desc_str}")
+
+     tool_lists_unique = {}
+     for k in tool_lists:
+         tool_lists_unique[k] = "\n\n".join(
+             set(e["documentation"] for e in tool_lists[k])
+         )
+     all_tools = "\n\n".join(set(tool_info))
+     tool_lists_unique["all"] = all_tools
+     return tool_lists_unique
+
+
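+ # Orchestrates the full pipeline: write_plans -> retrieve_tools -> (optionally)
+ # pick_plan -> write_and_test_code, streaming progress through an optional
+ # callback. A rough usage sketch (file name and task are hypothetical):
+ #
+ #     agent = VisionAgentCoder(verbosity=1)
+ #     result = agent.chat_with_workflow(
+ #         [{"role": "user", "content": "Count the cars", "media": ["street.jpg"]}]
+ #     )
+ #     print(result["code"])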
+ class VisionAgentCoder(Agent):
+     """Vision Agent Coder is an agentic framework that can output code based on a user
+     request. It can plan tasks, retrieve relevant tools, write code, write tests and
+     reflect on failed test cases to debug code. It is inspired by AgentCoder
+     https://arxiv.org/abs/2312.13010 and Data Interpreter https://arxiv.org/abs/2402.18679
+
+     Example
+     -------
+         >>> from vision_agent.agent import VisionAgentCoder
+         >>> agent = VisionAgentCoder()
+         >>> code = agent("What percentage of the area of the jar is filled with coffee beans?", media="jar.jpg")
+     """
+
+     def __init__(
+         self,
+         planner: Optional[LMM] = None,
+         coder: Optional[LMM] = None,
+         tester: Optional[LMM] = None,
+         debugger: Optional[LMM] = None,
+         tool_recommender: Optional[Sim] = None,
+         verbosity: int = 0,
+         report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
+         code_sandbox_runtime: Optional[str] = None,
+     ) -> None:
+         """Initialize the Vision Agent Coder.
+
+         Parameters:
+             planner (Optional[LMM]): The planner model to use. Defaults to OpenAILMM.
+             coder (Optional[LMM]): The coder model to use. Defaults to OpenAILMM.
+             tester (Optional[LMM]): The tester model to use. Defaults to OpenAILMM.
+             debugger (Optional[LMM]): The debugger model to use. Defaults to OpenAILMM.
+             tool_recommender (Optional[Sim]): The tool recommender model to use.
+             verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
+                 highest verbosity level which will output all intermediate debugging
+                 code.
+             report_progress_callback: a callback to report the progress of the agent.
+                 This is useful for streaming logs in a web application where multiple
+                 VisionAgentCoder instances are running in parallel. This callback
+                 ensures that the progress logs are not mixed up.
+             code_sandbox_runtime: the code sandbox runtime to use. A code sandbox is
+                 used to run the generated code. It can be one of the following
+                 values: None, "local" or "e2b". If None, VisionAgentCoder will read
+                 the value from the environment variable CODE_SANDBOX_RUNTIME. If it's
+                 also None, the local python runtime environment will be used.
+         """
+
+         self.planner = (
+             OpenAILMM(temperature=0.0, json_mode=True) if planner is None else planner
+         )
+         self.coder = OpenAILMM(temperature=0.0) if coder is None else coder
+         self.tester = OpenAILMM(temperature=0.0) if tester is None else tester
+         self.debugger = (
+             OpenAILMM(temperature=0.0, json_mode=True) if debugger is None else debugger
+         )
+         self.verbosity = verbosity
+         if self.verbosity > 0:
+             _LOGGER.setLevel(logging.INFO)
+
+         self.tool_recommender = (
+             Sim(T.TOOLS_DF, sim_key="desc")
+             if tool_recommender is None
+             else tool_recommender
+         )
+         self.report_progress_callback = report_progress_callback
+         self.code_sandbox_runtime = code_sandbox_runtime
+
+     def __call__(
+         self,
+         input: Union[str, List[Message]],
+         media: Optional[Union[str, Path]] = None,
+     ) -> str:
+         """Chat with VisionAgentCoder and return intermediate information regarding the
+         task.
+
+         Parameters:
+             input (Union[str, List[Message]]): A conversation in the format of
+                 [{"role": "user", "content": "describe your task here..."}] or a string
+                 of just the contents.
+             media (Optional[Union[str, Path]]): The media file to be used in the task.
+
+         Returns:
+             str: The code output by the VisionAgentCoder.
+         """
+
+         if isinstance(input, str):
+             input = [{"role": "user", "content": input}]
+             if media is not None:
+                 input[0]["media"] = [media]
+         results = self.chat_with_workflow(input)
+         results.pop("working_memory")
+         return results["code"]  # type: ignore
+
+     def chat_with_workflow(
+         self,
+         chat: List[Message],
+         test_multi_plan: bool = True,
+         display_visualization: bool = False,
+     ) -> Dict[str, Any]:
+         """Chat with VisionAgentCoder and return intermediate information regarding the
+         task.
+
+         Parameters:
+             chat (List[Message]): A conversation
+                 in the format of:
+                 [{"role": "user", "content": "describe your task here..."}]
+                 or if it contains media files, it should be in the format of:
+                 [{"role": "user", "content": "describe your task here...", "media": ["image1.jpg", "image2.jpg"]}]
+             test_multi_plan (bool): If True, it will test tools for multiple plans and
+                 pick the best one based on the tool results. If False, it will go
+                 with the first plan.
+             display_visualization (bool): If True, it opens a new window locally to
+                 show the image(s) created by visualization code (if there is any).
+
+         Returns:
+             Dict[str, Any]: A dictionary containing the code, test, test result, plan,
+                 and working memory of the agent.
+         """
+
+         if not chat:
+             raise ValueError("Chat cannot be empty.")
+
+         # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
+         with CodeInterpreterFactory.new_instance(
+             code_sandbox_runtime=self.code_sandbox_runtime
+         ) as code_interpreter:
+             chat = copy.deepcopy(chat)
+             media_list = []
+             for chat_i in chat:
+                 if "media" in chat_i:
+                     for media in chat_i["media"]:
+                         media = code_interpreter.upload_file(media)
+                         chat_i["content"] += f" Media name {media}"  # type: ignore
+                         media_list.append(media)
+
+             int_chat = cast(
+                 List[Message],
+                 [
+                     (
+                         {
+                             "role": c["role"],
+                             "content": c["content"],
+                             "media": c["media"],
+                         }
+                         if "media" in c
+                         else {"role": c["role"], "content": c["content"]}
+                     )
+                     for c in chat
+                 ],
+             )
+
+             code = ""
+             test = ""
+             working_memory: List[Dict[str, str]] = []
+             results = {"code": "", "test": "", "plan": []}
+             plan = []
+             success = False
+             self.log_progress(
+                 {
+                     "type": "log",
+                     "log_content": "Creating plans",
+                     "status": "started",
+                 }
+             )
+             plans = write_plans(
+                 int_chat,
+                 T.TOOL_DESCRIPTIONS,
+                 format_memory(working_memory),
+                 self.planner,
+             )
+
+             if self.verbosity >= 1:
+                 for p in plans:
+                     _LOGGER.info(
+                         f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
+                     )
+
+             tool_infos = retrieve_tools(
+                 plans,
+                 self.tool_recommender,
+                 self.log_progress,
+                 self.verbosity,
+             )
+
+             if test_multi_plan:
+                 best_plan, tool_output_str = pick_plan(
+                     int_chat,
+                     plans,
+                     tool_infos["all"],
+                     self.coder,
+                     code_interpreter,
+                     media_list,
+                     self.log_progress,
+                     verbosity=self.verbosity,
+                 )
+             else:
+                 best_plan = list(plans.keys())[0]
+                 tool_output_str = ""
+
+             if best_plan in plans and best_plan in tool_infos:
+                 plan_i = plans[best_plan]
+                 tool_info = tool_infos[best_plan]
+             else:
+                 if self.verbosity >= 1:
+                     _LOGGER.warning(
+                         f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
+                     )
+                 k = list(plans.keys())[0]
+                 plan_i = plans[k]
+                 tool_info = tool_infos[k]
+
+             self.log_progress(
+                 {
+                     "type": "log",
+                     "log_content": "Creating plans",
+                     "status": "completed",
+                     "payload": tool_info,
+                 }
+             )
+
+             if self.verbosity >= 1:
+                 _LOGGER.info(
+                     f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
+                 )
+
+             results = write_and_test_code(
+                 chat=[{"role": c["role"], "content": c["content"]} for c in int_chat],
+                 plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]),
+                 tool_info=tool_info,
+                 tool_output=tool_output_str,
+                 tool_utils=T.UTILITIES_DOCSTRING,
+                 working_memory=working_memory,
+                 coder=self.coder,
+                 tester=self.tester,
+                 debugger=self.debugger,
+                 code_interpreter=code_interpreter,
+                 log_progress=self.log_progress,
+                 verbosity=self.verbosity,
+                 media=media_list,
+             )
+             success = cast(bool, results["success"])
+             code = cast(str, results["code"])
+             test = cast(str, results["test"])
+             working_memory.extend(results["working_memory"])  # type: ignore
+             plan.append({"code": code, "test": test, "plan": plan_i})
+
+             execution_result = cast(Execution, results["test_result"])
+             self.log_progress(
+                 {
+                     "type": "final_code",
+                     "status": "completed" if success else "failed",
+                     "payload": {
+                         "code": DefaultImports.prepend_imports(code),
+                         "test": test,
+                         "result": execution_result.to_json(),
+                     },
+                 }
+             )
+
+             if display_visualization:
+                 for res in execution_result.results:
+                     if res.png:
+                         b64_to_pil(res.png).show()
+                     if res.mp4:
+                         play_video(res.mp4)
+
+             return {
+                 "code": DefaultImports.prepend_imports(code),
+                 "test": test,
+                 "test_result": execution_result,
+                 "plan": plan,
+                 "working_memory": working_memory,
+             }
+
+     def log_progress(self, data: Dict[str, Any]) -> None:
+         if self.report_progress_callback is not None:
+             self.report_progress_callback(data)
+
+
+ class AzureVisionAgentCoder(VisionAgentCoder):
+     """VisionAgentCoder that uses Azure OpenAI APIs for planning, coding, testing.
+
+     Pre-requisites:
+     1. Set the environment variable AZURE_OPENAI_API_KEY to your Azure OpenAI API key.
+     2. Set the environment variable AZURE_OPENAI_ENDPOINT to your Azure OpenAI endpoint.
+
+     Example
+     -------
+         >>> from vision_agent import AzureVisionAgentCoder
+         >>> agent = AzureVisionAgentCoder()
+         >>> code = agent("What percentage of the area of the jar is filled with coffee beans?", media="jar.jpg")
+     """
+
+     def __init__(
+         self,
+         planner: Optional[LMM] = None,
+         coder: Optional[LMM] = None,
+         tester: Optional[LMM] = None,
+         debugger: Optional[LMM] = None,
+         tool_recommender: Optional[Sim] = None,
+         verbosity: int = 0,
+         report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
+     ) -> None:
+         """Initialize the Vision Agent Coder.
+
+         Parameters:
+             planner (Optional[LMM]): The planner model to use. Defaults to AzureOpenAILMM.
+             coder (Optional[LMM]): The coder model to use. Defaults to AzureOpenAILMM.
+             tester (Optional[LMM]): The tester model to use. Defaults to AzureOpenAILMM.
+             debugger (Optional[LMM]): The debugger model to use. Defaults to AzureOpenAILMM.
+             tool_recommender (Optional[Sim]): The tool recommender model to use.
+             verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
+                 highest verbosity level which will output all intermediate debugging
+                 code.
+             report_progress_callback: a callback to report the progress of the agent.
+                 This is useful for streaming logs in a web application where multiple
+                 VisionAgentCoder instances are running in parallel. This callback
+                 ensures that the progress logs are not mixed up.
+         """
+         super().__init__(
+             planner=(
+                 AzureOpenAILMM(temperature=0.0, json_mode=True)
+                 if planner is None
+                 else planner
+             ),
+             coder=AzureOpenAILMM(temperature=0.0) if coder is None else coder,
+             tester=AzureOpenAILMM(temperature=0.0) if tester is None else tester,
+             debugger=(
+                 AzureOpenAILMM(temperature=0.0, json_mode=True)
+                 if debugger is None
+                 else debugger
+             ),
+             tool_recommender=(
+                 AzureSim(T.TOOLS_DF, sim_key="desc")
+                 if tool_recommender is None
+                 else tool_recommender
+             ),
+             verbosity=verbosity,
+             report_progress_callback=report_progress_callback,
+         )