vision-agent 0.2.91__py3-none-any.whl → 0.2.92__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,57 +1,39 @@
  import copy
- import difflib
- import json
  import logging
- import sys
- import tempfile
+ import os
  from pathlib import Path
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
+ from typing import Any, Dict, List, Optional, Union, cast

- from langsmith import traceable
- from PIL import Image
- from rich.console import Console
- from rich.style import Style
- from rich.syntax import Syntax
- from tabulate import tabulate
-
- import vision_agent.tools as T
  from vision_agent.agent import Agent
+ from vision_agent.agent.agent_utils import extract_json
  from vision_agent.agent.vision_agent_prompts import (
-     CODE,
-     FIX_BUG,
-     FULL_TASK,
-     PICK_PLAN,
-     PLAN,
-     PREVIOUS_FAILED,
-     SIMPLE_TEST,
-     TEST_PLANS,
-     USER_REQ,
+     EXAMPLES_CODE1,
+     EXAMPLES_CODE2,
+     VA_CODE,
  )
- from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OpenAILMM
- from vision_agent.utils import CodeInterpreterFactory, Execution
+ from vision_agent.lmm import LMM, Message, OpenAILMM
+ from vision_agent.tools import META_TOOL_DOCSTRING
+ from vision_agent.utils import CodeInterpreterFactory
  from vision_agent.utils.execute import CodeInterpreter
- from vision_agent.utils.image_utils import b64_to_pil
- from vision_agent.utils.sim import AzureSim, Sim
- from vision_agent.utils.video import play_video

- logging.basicConfig(stream=sys.stdout)
+ logging.basicConfig(level=logging.INFO)
  _LOGGER = logging.getLogger(__name__)
- _MAX_TABULATE_COL_WIDTH = 80
- _CONSOLE = Console()
+ WORKSPACE = Path(os.getenv("WORKSPACE", ""))
+ WORKSPACE.mkdir(parents=True, exist_ok=True)
+ if str(WORKSPACE) != "":
+     os.environ["PYTHONPATH"] = f"{WORKSPACE}:{os.getenv('PYTHONPATH', '')}"


  class DefaultImports:
-     """Container for default imports used in the code execution."""
-
-     common_imports = [
+     code = [
          "from typing import *",
-         "from pillow_heif import register_heif_opener",
-         "register_heif_opener()",
+         "from vision_agent.utils.execute import CodeInterpreter",
+         "from vision_agent.tools.meta_tools import generate_vision_code, edit_vision_code, open_file, create_file, scroll_up, scroll_down, edit_file, get_tool_descriptions",
      ]

      @staticmethod
      def to_code_string() -> str:
-         return "\n".join(DefaultImports.common_imports + T.__new_tools__)
+         return "\n".join(DefaultImports.code)

      @staticmethod
      def prepend_imports(code: str) -> str:
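
In the hunk above, the imports prepended to every executed cell change from image-handling helpers to the SWE-style meta-tools, and a WORKSPACE directory is created and pushed onto PYTHONPATH at import time. A minimal standalone sketch of what the new DefaultImports.prepend_imports produces (user_code is a hypothetical cell body, not from the package):

from typing import List

# The new DefaultImports.code list, reproduced from the hunk above.
code: List[str] = [
    "from typing import *",
    "from vision_agent.utils.execute import CodeInterpreter",
    "from vision_agent.tools.meta_tools import generate_vision_code, edit_vision_code, open_file, create_file, scroll_up, scroll_down, edit_file, get_tool_descriptions",
]

user_code = "print(get_tool_descriptions())"  # hypothetical cell body
# Mirrors prepend_imports: the import prelude, a blank line, then the cell.
print("\n".join(code) + "\n\n" + user_code)
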
@@ -61,686 +43,143 @@ class DefaultImports:
          return DefaultImports.to_code_string() + "\n\n" + code


- def get_diff(before: str, after: str) -> str:
-     return "".join(
-         difflib.unified_diff(
-             before.splitlines(keepends=True), after.splitlines(keepends=True)
-         )
-     )
-
-
- def format_memory(memory: List[Dict[str, str]]) -> str:
-     output_str = ""
-     for i, m in enumerate(memory):
-         output_str += f"### Feedback {i}:\n"
-         output_str += f"Code {i}:\n```python\n{m['code']}```\n\n"
-         output_str += f"Feedback {i}: {m['feedback']}\n\n"
-         if "edits" in m:
-             output_str += f"Edits {i}:\n{m['edits']}\n"
-         output_str += "\n"
-
-     return output_str
-
-
- def format_plans(plans: Dict[str, Any]) -> str:
-     plan_str = ""
-     for k, v in plans.items():
-         plan_str += f"{k}:\n"
-         plan_str += "-" + "\n-".join([e["instructions"] for e in v])
-
-     return plan_str
-
-
- def extract_code(code: str) -> str:
-     if "\n```python" in code:
-         start = "\n```python"
-     elif "```python" in code:
-         start = "```python"
-     else:
-         return code
-
-     code = code[code.find(start) + len(start) :]
-     code = code[: code.find("```")]
-     if code.startswith("python\n"):
-         code = code[len("python\n") :]
-     return code
-
-
- def extract_json(json_str: str) -> Dict[str, Any]:
-     try:
-         json_dict = json.loads(json_str)
-     except json.JSONDecodeError:
-         input_json_str = json_str
-         if "```json" in json_str:
-             json_str = json_str[json_str.find("```json") + len("```json") :]
-             json_str = json_str[: json_str.find("```")]
-         elif "```" in json_str:
-             json_str = json_str[json_str.find("```") + len("```") :]
-             # get the last ``` not one from an intermediate string
-             json_str = json_str[: json_str.find("}```")]
-         try:
-             json_dict = json.loads(json_str)
-         except json.JSONDecodeError as e:
-             error_msg = f"Could not extract JSON from the given str: {json_str}.\nFunction input:\n{input_json_str}"
-             _LOGGER.exception(error_msg)
-             raise ValueError(error_msg) from e
-     return json_dict  # type: ignore
-
-
- def extract_image(
-     media: Optional[Sequence[Union[str, Path]]]
- ) -> Optional[Sequence[Union[str, Path]]]:
-     if media is None:
-         return None
-
-     new_media = []
-     for m in media:
-         m = Path(m)
-         extension = m.suffix
-         if extension in [".jpg", ".jpeg", ".png", ".bmp"]:
-             new_media.append(m)
-         elif extension in [".mp4", ".mov"]:
-             frames = T.extract_frames(m)
-             with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
-                 if len(frames) > 0:
-                     Image.fromarray(frames[0][0]).save(tmp.name)
-                 new_media.append(Path(tmp.name))
-     if len(new_media) == 0:
-         return None
-     return new_media
-
-
- @traceable
- def write_plans(
-     chat: List[Message],
-     tool_desc: str,
-     working_memory: str,
-     model: LMM,
- ) -> Dict[str, Any]:
-     chat = copy.deepcopy(chat)
-     if chat[-1]["role"] != "user":
-         raise ValueError("Last chat message must be from the user.")
-
-     user_request = chat[-1]["content"]
-     context = USER_REQ.format(user_request=user_request)
-     prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
-     chat[-1]["content"] = prompt
-     return extract_json(model.chat(chat))
-
-
- @traceable
- def pick_plan(
-     chat: List[Message],
-     plans: Dict[str, Any],
-     tool_infos: Dict[str, str],
-     model: LMM,
-     code_interpreter: CodeInterpreter,
-     test_multi_plan: bool,
-     log_progress: Callable[[Dict[str, Any]], None],
-     verbosity: int = 0,
-     max_retries: int = 3,
- ) -> Tuple[Any, str, str]:
-     if not test_multi_plan:
-         k = list(plans.keys())[0]
-         log_progress(
-             {
-                 "type": "log",
-                 "log_content": "Plans created",
-                 "status": "completed",
-                 "payload": plans[k],
-             }
-         )
-         return plans[k], tool_infos[k], ""
-
-     log_progress(
-         {
-             "type": "log",
-             "log_content": "Generating code to pick best plan",
-             "status": "started",
-         }
-     )
-     all_tool_info = tool_infos["all"]
+ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
      chat = copy.deepcopy(chat)
-     if chat[-1]["role"] != "user":
-         raise ValueError("Last chat message must be from the user.")
-
-     plan_str = format_plans(plans)
-     prompt = TEST_PLANS.format(
-         docstring=all_tool_info, plans=plan_str, previous_attempts=""
-     )

-     code = extract_code(model(prompt))
-     log_progress(
-         {
-             "type": "log",
-             "log_content": "Executing code to test plan",
-             "code": code,
-             "status": "running",
-         }
-     )
-     tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
-     tool_output_str = ""
-     if len(tool_output.logs.stdout) > 0:
-         tool_output_str = tool_output.logs.stdout[0]
-
-     if verbosity == 2:
-         _print_code("Initial code and tests:", code)
-         _LOGGER.info(f"Initial code execution result:\n{tool_output.text()}")
-
-     log_progress(
-         {
-             "type": "log",
-             "log_content": (
-                 "Code execution succeed"
-                 if tool_output.success
-                 else "Code execution failed"
-             ),
-             "payload": tool_output.to_json(),
-             "status": "completed" if tool_output.success else "failed",
-         }
-     )
-     # retry if the tool output is empty or code fails
-     count = 0
-     while (not tool_output.success or tool_output_str == "") and count < max_retries:
-         prompt = TEST_PLANS.format(
-             docstring=all_tool_info,
-             plans=plan_str,
-             previous_attempts=PREVIOUS_FAILED.format(
-                 code=code, error=tool_output.text()
-             ),
-         )
-         log_progress(
-             {
-                 "type": "log",
-                 "log_content": "Retry running code",
-                 "code": code,
-                 "status": "running",
-             }
-         )
-         code = extract_code(model(prompt))
-         tool_output = code_interpreter.exec_isolation(
-             DefaultImports.prepend_imports(code)
-         )
-         log_progress(
-             {
-                 "type": "log",
-                 "log_content": (
-                     "Code execution succeed"
-                     if tool_output.success
-                     else "Code execution failed"
-                 ),
-                 "code": code,
-                 "payload": {
-                     "result": tool_output.to_json(),
-                 },
-                 "status": "completed" if tool_output.success else "failed",
-             }
-         )
-         tool_output_str = ""
-         if len(tool_output.logs.stdout) > 0:
-             tool_output_str = tool_output.logs.stdout[0]
-
-         if verbosity == 2:
-             _print_code("Code and test after attempted fix:", code)
-             _LOGGER.info(f"Code execution result after attempte {count}")
-
-         count += 1
-
-     if verbosity >= 1:
-         _print_code("Final code:", code)
-
-     user_req = chat[-1]["content"]
-     context = USER_REQ.format(user_request=user_req)
-     # because the tool picker model gets the image as well, we have to be careful with
-     # how much text we send it, so we truncate the tool output to 20,000 characters
-     prompt = PICK_PLAN.format(
-         context=context,
-         plans=format_plans(plans),
-         tool_output=tool_output_str[:20_000],
-     )
-     chat[-1]["content"] = prompt
-     best_plan = extract_json(model(chat))
-     if verbosity >= 1:
-         _LOGGER.info(f"Best plan:\n{best_plan}")
-
-     plan = best_plan["best_plan"]
-     if plan in plans and plan in tool_infos:
-         best_plans = plans[plan]
-         best_tool_infos = tool_infos[plan]
+     conversation = ""
+     for chat_i in chat:
+         if chat_i["role"] == "user":
+             conversation += f"USER: {chat_i['content']}\n\n"
+         elif chat_i["role"] == "observation":
+             conversation += f"OBSERVATION:\n{chat_i['content']}\n\n"
+         elif chat_i["role"] == "assistant":
+             conversation += f"AGENT: {chat_i['content']}\n\n"
+         else:
+             raise ValueError(f"role {chat_i['role']} is not supported")
+
+     prompt = VA_CODE.format(
+         documentation=META_TOOL_DOCSTRING,
+         examples=f"{EXAMPLES_CODE1}\n{EXAMPLES_CODE2}",
+         dir=WORKSPACE,
+         conversation=conversation,
+     )
+     return extract_json(orch([{"role": "user", "content": prompt}]))
+
+
+ def run_code_action(code: str, code_interpreter: CodeInterpreter) -> str:
+     # Note the code interpreter needs to keep running in the same environment because
+     # the SWE tools hold state like line numbers and currently open files.
+     result = code_interpreter.exec_cell(DefaultImports.prepend_imports(code))
+
+     return_str = ""
+     if result.success:
+         for res in result.results:
+             if res.text is not None:
+                 return_str += res.text.replace("\\n", "\n")
+         if result.logs.stdout:
+             return_str += "----- stdout -----\n"
+             for log in result.logs.stdout:
+                 return_str += log.replace("\\n", "\n")
      else:
-         if verbosity >= 1:
-             _LOGGER.warning(
-                 f"Best plan {plan} not found in plans or tool_infos. Using the first plan and tool info."
+         # for log in result.logs.stderr:
+         #     return_str += log.replace("\\n", "\n")
+         if result.error:
+             return_str += (
+                 "\n" + result.error.value + "\n".join(result.error.traceback_raw)
              )
-         k = list(plans.keys())[0]
-         best_plans = plans[k]
-         best_tool_infos = tool_infos[k]
-
-     log_progress(
-         {
-             "type": "log",
-             "log_content": "Picked best plan",
-             "status": "complete",
-             "payload": best_plans,
-         }
-     )
-     return best_plans, best_tool_infos, tool_output_str
-
-
- @traceable
- def write_code(
-     coder: LMM,
-     chat: List[Message],
-     plan: str,
-     tool_info: str,
-     tool_output: str,
-     feedback: str,
- ) -> str:
-     chat = copy.deepcopy(chat)
-     if chat[-1]["role"] != "user":
-         raise ValueError("Last chat message must be from the user.")
-
-     user_request = chat[-1]["content"]
-     prompt = CODE.format(
-         docstring=tool_info,
-         question=FULL_TASK.format(user_request=user_request, subtasks=plan),
-         tool_output=tool_output,
-         feedback=feedback,
-     )
-     chat[-1]["content"] = prompt
-     return extract_code(coder(chat))
-
-
- @traceable
- def write_test(
-     tester: LMM,
-     chat: List[Message],
-     tool_utils: str,
-     code: str,
-     feedback: str,
-     media: Optional[Sequence[Union[str, Path]]] = None,
- ) -> str:
-     chat = copy.deepcopy(chat)
-     if chat[-1]["role"] != "user":
-         raise ValueError("Last chat message must be from the user.")
-
-     user_request = chat[-1]["content"]
-     prompt = SIMPLE_TEST.format(
-         docstring=tool_utils,
-         question=user_request,
-         code=code,
-         feedback=feedback,
-         media=media,
-     )
-     chat[-1]["content"] = prompt
-     return extract_code(tester(chat))
-
-
- def write_and_test_code(
-     chat: List[Message],
-     plan: str,
-     tool_info: str,
-     tool_output: str,
-     tool_utils: str,
-     working_memory: List[Dict[str, str]],
-     coder: LMM,
-     tester: LMM,
-     debugger: LMM,
-     code_interpreter: CodeInterpreter,
-     log_progress: Callable[[Dict[str, Any]], None],
-     verbosity: int = 0,
-     max_retries: int = 3,
-     media: Optional[Sequence[Union[str, Path]]] = None,
- ) -> Dict[str, Any]:
-     log_progress(
-         {
-             "type": "log",
-             "log_content": "Generating code",
-             "status": "started",
-         }
-     )
-     code = write_code(
-         coder,
-         chat,
-         plan,
-         tool_info,
-         tool_output,
-         format_memory(working_memory),
-     )
-     test = write_test(
-         tester, chat, tool_utils, code, format_memory(working_memory), media
-     )

-     log_progress(
-         {
-             "type": "log",
-             "log_content": "Running code",
-             "status": "running",
-             "code": DefaultImports.prepend_imports(code),
-             "payload": {
-                 "test": test,
-             },
-         }
-     )
-     result = code_interpreter.exec_isolation(
-         f"{DefaultImports.to_code_string()}\n{code}\n{test}"
-     )
-     log_progress(
-         {
-             "type": "log",
-             "log_content": (
-                 "Code execution succeed" if result.success else "Code execution failed"
-             ),
-             "status": "completed" if result.success else "failed",
-             "code": DefaultImports.prepend_imports(code),
-             "payload": {
-                 "test": test,
-                 "result": result.to_json(),
-             },
-         }
-     )
-     if verbosity == 2:
-         _print_code("Initial code and tests:", code, test)
-         _LOGGER.info(
-             f"Initial code execution result:\n{result.text(include_logs=True)}"
-         )
+     return return_str

-     count = 0
-     new_working_memory: List[Dict[str, str]] = []
-     while not result.success and count < max_retries:
-         if verbosity == 2:
-             _LOGGER.info(f"Start debugging attempt {count + 1}")
-         code, test, result = debug_code(
-             working_memory,
-             debugger,
-             code_interpreter,
-             code,
-             test,
-             result,
-             new_working_memory,
-             log_progress,
-             verbosity,
-         )
-         count += 1
-
-     if verbosity >= 1:
-         _print_code("Final code and tests:", code, test)
-
-     return {
-         "code": code,
-         "test": test,
-         "success": result.success,
-         "test_result": result,
-         "working_memory": new_working_memory,
-     }
-
-
- @traceable
- def debug_code(
-     working_memory: List[Dict[str, str]],
-     debugger: LMM,
-     code_interpreter: CodeInterpreter,
-     code: str,
-     test: str,
-     result: Execution,
-     new_working_memory: List[Dict[str, str]],
-     log_progress: Callable[[Dict[str, Any]], None],
-     verbosity: int = 0,
- ) -> tuple[str, str, Execution]:
-     log_progress(
-         {
-             "type": "code",
-             "status": "started",
-         }
-     )

-     fixed_code_and_test = {"code": "", "test": "", "reflections": ""}
-     success = False
-     count = 0
-     while not success and count < 3:
-         try:
-             fixed_code_and_test = extract_json(
-                 debugger(
-                     FIX_BUG.format(
-                         code=code,
-                         tests=test,
-                         result="\n".join(result.text().splitlines()[-100:]),
-                         feedback=format_memory(working_memory + new_working_memory),
-                     )
-                 )
-             )
-             success = True
-         except Exception as e:
-             _LOGGER.exception(f"Error while extracting JSON: {e}")
-
-         count += 1
-
-     old_code = code
-     old_test = test
-
-     if fixed_code_and_test["code"].strip() != "":
-         code = extract_code(fixed_code_and_test["code"])
-     if fixed_code_and_test["test"].strip() != "":
-         test = extract_code(fixed_code_and_test["test"])
-
-     new_working_memory.append(
-         {
-             "code": f"{code}\n{test}",
-             "feedback": fixed_code_and_test["reflections"],
-             "edits": get_diff(f"{old_code}\n{old_test}", f"{code}\n{test}"),
-         }
-     )
-     log_progress(
-         {
-             "type": "code",
-             "status": "running",
-             "payload": {
-                 "code": DefaultImports.prepend_imports(code),
-                 "test": test,
-             },
-         }
-     )
-
-     result = code_interpreter.exec_isolation(
-         f"{DefaultImports.to_code_string()}\n{code}\n{test}"
-     )
-     log_progress(
-         {
-             "type": "code",
-             "status": "completed" if result.success else "failed",
-             "payload": {
-                 "code": DefaultImports.prepend_imports(code),
-                 "test": test,
-                 "result": result.to_json(),
-             },
-         }
-     )
-     if verbosity == 2:
-         _print_code("Code and test after attempted fix:", code, test)
-         _LOGGER.info(
-             f"Reflection: {fixed_code_and_test['reflections']}\nCode execution result after attempted fix: {result.text(include_logs=True)}"
-         )
-
-     return code, test, result
-
-
- def _print_code(title: str, code: str, test: Optional[str] = None) -> None:
-     _CONSOLE.print(title, style=Style(bgcolor="dark_orange3", bold=True))
-     _CONSOLE.print("=" * 30 + " Code " + "=" * 30)
-     _CONSOLE.print(
-         Syntax(
-             DefaultImports.prepend_imports(code),
-             "python",
-             theme="gruvbox-dark",
-             line_numbers=True,
-         )
-     )
-     if test:
-         _CONSOLE.print("=" * 30 + " Test " + "=" * 30)
-         _CONSOLE.print(Syntax(test, "python", theme="gruvbox-dark", line_numbers=True))
-
-
- def retrieve_tools(
-     plans: Dict[str, List[Dict[str, str]]],
-     tool_recommender: Sim,
-     verbosity: int = 0,
- ) -> Tuple[Dict[str, str], Dict[str, List[Dict[str, str]]]]:
-     tool_info = []
-     tool_desc = []
-     tool_lists: Dict[str, List[Dict[str, str]]] = {}
-     for k, plan in plans.items():
-         tool_lists[k] = []
-         for task in plan:
-             tools = tool_recommender.top_k(task["instructions"], k=2, thresh=0.3)
-             tool_info.extend([e["doc"] for e in tools])
-             tool_desc.extend([e["desc"] for e in tools])
-             tool_lists[k].extend(
-                 {
-                     "plan": task["instructions"] if index == 0 else "",
-                     "tool": e["desc"].strip().split()[0],
-                     "documentation": e["doc"],
-                 }
-                 for index, e in enumerate(tools)
-             )
-
-     if verbosity == 2:
-         tool_desc_str = "\n".join(set(tool_desc))
-         _LOGGER.info(f"Tools Description:\n{tool_desc_str}")
-
-     tool_lists_unique = {}
-     for k in tool_lists:
-         tool_lists_unique[k] = "\n\n".join(
-             set(e["documentation"] for e in tool_lists[k])
-         )
-     all_tools = "\n\n".join(set(tool_info))
-     tool_lists_unique["all"] = all_tools
-     return tool_lists_unique, tool_lists
+ def parse_execution(response: str) -> Optional[str]:
+     code = None
+     if "<execute_python>" in response:
+         code = response[response.find("<execute_python>") + len("<execute_python>") :]
+         code = code[: code.find("</execute_python>")]
+     return code


  class VisionAgent(Agent):
-     """Vision Agent is an agentic framework that can output code based on a user
-     request. It can plan tasks, retrieve relevant tools, write code, write tests and
-     reflect on failed test cases to debug code. It is inspired by AgentCoder
-     https://arxiv.org/abs/2312.13010 and Data Interpeter
-     https://arxiv.org/abs/2402.18679
+     """Vision Agent is an agent that can chat with the user and call tools or other
+     agents to generate code for it. Vision Agent uses python code to execute actions for
+     the user. Vision Agent is inspired by by OpenDev
+     https://github.com/OpenDevin/OpenDevin and CodeAct https://arxiv.org/abs/2402.01030

      Example
      -------
-     >>> from vision_agent import VisionAgent
+     >>> from vision_agent.agent import VisionAgent
      >>> agent = VisionAgent()
-     >>> code = agent("What percentage of the area of the jar is filled with coffee beans?", media="jar.jpg")
+     >>> resp = agent("Hello")
+     >>> resp.append({"role": "user", "content": "Can you write a function that counts dogs?", "media": ["dog.jpg"]})
+     >>> resp = agent(resp)
      """

      def __init__(
          self,
-         planner: Optional[LMM] = None,
-         coder: Optional[LMM] = None,
-         tester: Optional[LMM] = None,
-         debugger: Optional[LMM] = None,
-         tool_recommender: Optional[Sim] = None,
+         agent: Optional[LMM] = None,
          verbosity: int = 0,
-         report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
          code_sandbox_runtime: Optional[str] = None,
      ) -> None:
-         """Initialize the Vision Agent.
-
-         Parameters:
-             planner (Optional[LMM]): The planner model to use. Defaults to OpenAILMM.
-             coder (Optional[LMM]): The coder model to use. Defaults to OpenAILMM.
-             tester (Optional[LMM]): The tester model to use. Defaults to OpenAILMM.
-             debugger (Optional[LMM]): The debugger model to
-             tool_recommender (Optional[Sim]): The tool recommender model to use.
-             verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
-                 highest verbosity level which will output all intermediate debugging
-                 code.
-             report_progress_callback: a callback to report the progress of the agent.
-                 This is useful for streaming logs in a web application where multiple
-                 VisionAgent instances are running in parallel. This callback ensures
-                 that the progress are not mixed up.
-             code_sandbox_runtime: the code sandbox runtime to use. A code sandbox is
-                 used to run the generated code. It can be one of the following
-                 values: None, "local" or "e2b". If None, Vision Agent will read the
-                 value from the environment variable CODE_SANDBOX_RUNTIME. If it's
-                 also None, the local python runtime environment will be used.
-         """
-
-         self.planner = (
-             OpenAILMM(temperature=0.0, json_mode=True) if planner is None else planner
-         )
-         self.coder = OpenAILMM(temperature=0.0) if coder is None else coder
-         self.tester = OpenAILMM(temperature=0.0) if tester is None else tester
-         self.debugger = (
-             OpenAILMM(temperature=0.0, json_mode=True) if debugger is None else debugger
-         )
-
-         self.tool_recommender = (
-             Sim(T.TOOLS_DF, sim_key="desc")
-             if tool_recommender is None
-             else tool_recommender
+         self.agent = (
+             OpenAILMM(temperature=0.0, json_mode=True) if agent is None else agent
          )
+         self.max_iterations = 100
          self.verbosity = verbosity
-         self.max_retries = 2
-         self.report_progress_callback = report_progress_callback
          self.code_sandbox_runtime = code_sandbox_runtime
+         if self.verbosity >= 1:
+             _LOGGER.setLevel(logging.INFO)

      def __call__(
          self,
          input: Union[str, List[Message]],
          media: Optional[Union[str, Path]] = None,
      ) -> str:
-         """Chat with Vision Agent and return intermediate information regarding the task.
+         """Chat with VisionAgent and get the conversation response.

          Parameters:
-             input (Union[List[Dict[str, str]], str]): A conversation in the format of
-                 [{"role": "user", "content": "describe your task here..."}] or a string
-                 of just the contents.
+             input (Union[str, List[Message]): A conversation in the format of
+                 [{"role": "user", "content": "describe your task here..."}, ...] or a
+                 string of just the contents.
              media (Optional[Union[str, Path]]): The media file to be used in the task.

          Returns:
-             str: The code output by the Vision Agent.
+             str: The conversation response.
          """
-
          if isinstance(input, str):
              input = [{"role": "user", "content": input}]
              if media is not None:
                  input[0]["media"] = [media]
-         results = self.chat_with_workflow(input)
-         results.pop("working_memory")
+         results = self.chat_with_code(input)
          return results  # type: ignore

-     @traceable
-     def chat_with_workflow(
+     def chat_with_code(
          self,
          chat: List[Message],
-         test_multi_plan: bool = True,
-         display_visualization: bool = False,
-     ) -> Dict[str, Any]:
-         """Chat with Vision Agent and return intermediate information regarding the task.
+     ) -> List[Message]:
+         """Chat with VisionAgent, it will use code to execute actions to accomplish
+         its tasks.

          Parameters:
-             chat (List[MediaChatItem]): A conversation
+             chat (List[Message]): A conversation
                  in the format of:
                  [{"role": "user", "content": "describe your task here..."}]
                  or if it contains media files, it should be in the format of:
                  [{"role": "user", "content": "describe your task here...", "media": ["image1.jpg", "image2.jpg"]}]
-             display_visualization (bool): If True, it opens a new window locally to
-                 show the image(s) created by visualization code (if there is any).

          Returns:
-             Dict[str, Any]: A dictionary containing the code, test, test result, plan,
-                 and working memory of the agent.
+             List[Message]: The conversation response.
          """

          if not chat:
-             raise ValueError("Chat cannot be empty.")
+             raise ValueError("chat cannot be empty")

-         # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
          with CodeInterpreterFactory.new_instance(
              code_sandbox_runtime=self.code_sandbox_runtime
          ) as code_interpreter:
-             chat = copy.deepcopy(chat)
+             orig_chat = copy.deepcopy(chat)
+             int_chat = copy.deepcopy(chat)
              media_list = []
-             for chat_i in chat:
+             for chat_i in int_chat:
                  if "media" in chat_i:
                      for media in chat_i["media"]:
                          media = code_interpreter.upload_file(media)
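
The replacement code above is a parse-act-observe loop: run_conversation formats the chat history into the VA_CODE prompt and asks the orchestrator model for JSON, parse_execution pulls any code out of <execute_python> tags, and run_code_action executes it in a single long-lived interpreter so the SWE tools keep their open-file state. A standalone sketch of the tag extraction, duplicating parse_execution from the hunk so it runs without the package:

from typing import Optional

def parse_execution(response: str) -> Optional[str]:
    # Return the text between <execute_python> tags, or None if absent.
    code = None
    if "<execute_python>" in response:
        code = response[response.find("<execute_python>") + len("<execute_python>") :]
        code = code[: code.find("</execute_python>")]
    return code

print(parse_execution("On it.\n<execute_python>\nopen_file('app.py')\n</execute_python>"))
# prints the inner snippet: open_file('app.py')
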
@@ -759,180 +198,33 @@ class VisionAgent(Agent):
                          if "media" in c
                          else {"role": c["role"], "content": c["content"]}
                      )
-                     for c in chat
+                     for c in int_chat
                  ],
              )

-             code = ""
-             test = ""
-             working_memory: List[Dict[str, str]] = []
-             results = {"code": "", "test": "", "plan": []}
-             plan = []
-
-             self.log_progress(
-                 {
-                     "type": "log",
-                     "log_content": "Creating plans",
-                     "status": "started",
-                 }
-             )
-             plans = write_plans(
-                 int_chat,
-                 T.TOOL_DESCRIPTIONS,
-                 format_memory(working_memory),
-                 self.planner,
-             )
-
-             if self.verbosity >= 1 and test_multi_plan:
-                 for p in plans:
-                     _LOGGER.info(
-                         f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
-                     )
+             finished = False
+             iterations = 0
+             while not finished and iterations < self.max_iterations:
+                 response = run_conversation(self.agent, int_chat)
+                 if self.verbosity >= 1:
+                     _LOGGER.info(response)
+                 int_chat.append({"role": "assistant", "content": str(response)})
+                 orig_chat.append({"role": "assistant", "content": str(response)})

-             tool_infos, tool_lists = retrieve_tools(
-                 plans,
-                 self.tool_recommender,
-                 self.verbosity,
-             )
+                 if response["let_user_respond"]:
+                     break

-             if test_multi_plan:
-                 self.log_progress(
-                     {
-                         "type": "log",
-                         "log_content": "Creating plans",
-                         "status": "completed",
-                         "payload": tool_lists,
-                     }
-                 )
-
-                 best_plan, best_tool_info, tool_output_str = pick_plan(
-                     int_chat,
-                     plans,
-                     tool_infos,
-                     self.coder,
-                     code_interpreter,
-                     test_multi_plan,
-                     self.log_progress,
-                     verbosity=self.verbosity,
-                 )
+                 code_action = parse_execution(response["response"])

-             if self.verbosity >= 1:
-                 _LOGGER.info(
-                     f"Picked best plan:\n{tabulate(tabular_data=best_plan, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
-                 )
-
-             results = write_and_test_code(
-                 chat=[{"role": c["role"], "content": c["content"]} for c in int_chat],
-                 plan="\n-" + "\n-".join([e["instructions"] for e in best_plan]),
-                 tool_info=best_tool_info,
-                 tool_output=tool_output_str,
-                 tool_utils=T.UTILITIES_DOCSTRING,
-                 working_memory=working_memory,
-                 coder=self.coder,
-                 tester=self.tester,
-                 debugger=self.debugger,
-                 code_interpreter=code_interpreter,
-                 log_progress=self.log_progress,
-                 verbosity=self.verbosity,
-                 media=media_list,
-             )
-             success = cast(bool, results["success"])
-             code = cast(str, results["code"])
-             test = cast(str, results["test"])
-             working_memory.extend(results["working_memory"])  # type: ignore
-             plan.append({"code": code, "test": test, "plan": best_plan})
-
-             execution_result = cast(Execution, results["test_result"])
-             self.log_progress(
-                 {
-                     "type": "final_code",
-                     "status": "completed" if success else "failed",
-                     "payload": {
-                         "code": DefaultImports.prepend_imports(code),
-                         "test": test,
-                         "result": execution_result.to_json(),
-                     },
-                 }
-             )
+                 if code_action is not None:
+                     obs = run_code_action(code_action, code_interpreter)
+                     if self.verbosity >= 1:
+                         _LOGGER.info(obs)
+                     int_chat.append({"role": "observation", "content": obs})
+                     orig_chat.append({"role": "observation", "content": obs})

-             if display_visualization:
-                 for res in execution_result.results:
-                     if res.png:
-                         b64_to_pil(res.png).show()
-                     if res.mp4:
-                         play_video(res.mp4)
-
-             return {
-                 "code": DefaultImports.prepend_imports(code),
-                 "test": test,
-                 "test_result": execution_result,
-                 "plan": plan,
-                 "working_memory": working_memory,
-             }
+                 iterations += 1
+             return orig_chat

      def log_progress(self, data: Dict[str, Any]) -> None:
-         if self.report_progress_callback is not None:
-             self.report_progress_callback(data)
-
-
- class AzureVisionAgent(VisionAgent):
-     """Vision Agent that uses Azure OpenAI APIs for planning, coding, testing.
-
-     Pre-requisites:
-     1. Set the environment variable AZURE_OPENAI_API_KEY to your Azure OpenAI API key.
-     2. Set the environment variable AZURE_OPENAI_ENDPOINT to your Azure OpenAI endpoint.
-
-     Example
-     -------
-     >>> from vision_agent import AzureVisionAgent
-     >>> agent = AzureVisionAgent()
-     >>> code = agent("What percentage of the area of the jar is filled with coffee beans?", media="jar.jpg")
-     """
-
-     def __init__(
-         self,
-         planner: Optional[LMM] = None,
-         coder: Optional[LMM] = None,
-         tester: Optional[LMM] = None,
-         debugger: Optional[LMM] = None,
-         tool_recommender: Optional[Sim] = None,
-         verbosity: int = 0,
-         report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
-     ) -> None:
-         """Initialize the Vision Agent.
-
-         Parameters:
-             planner (Optional[LMM]): The planner model to use. Defaults to OpenAILMM.
-             coder (Optional[LMM]): The coder model to use. Defaults to OpenAILMM.
-             tester (Optional[LMM]): The tester model to use. Defaults to OpenAILMM.
-             debugger (Optional[LMM]): The debugger model to
-             tool_recommender (Optional[Sim]): The tool recommender model to use.
-             verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
-                 highest verbosity level which will output all intermediate debugging
-                 code.
-             report_progress_callback: a callback to report the progress of the agent.
-                 This is useful for streaming logs in a web application where multiple
-                 VisionAgent instances are running in parallel. This callback ensures
-                 that the progress are not mixed up.
-         """
-         super().__init__(
-             planner=(
-                 AzureOpenAILMM(temperature=0.0, json_mode=True)
-                 if planner is None
-                 else planner
-             ),
-             coder=AzureOpenAILMM(temperature=0.0) if coder is None else coder,
-             tester=AzureOpenAILMM(temperature=0.0) if tester is None else tester,
-             debugger=(
-                 AzureOpenAILMM(temperature=0.0, json_mode=True)
-                 if debugger is None
-                 else debugger
-             ),
-             tool_recommender=(
-                 AzureSim(T.TOOLS_DF, sim_key="desc")
-                 if tool_recommender is None
-                 else tool_recommender
-             ),
-             verbosity=verbosity,
-             report_progress_callback=report_progress_callback,
-         )
+         pass
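
Taken together, the public behavior changes from one-shot code generation (the old chat_with_workflow returning code, test, and plan) to a conversational loop that returns the message list; the Azure variant is dropped entirely. A usage sketch following the new docstring (assumes OpenAI credentials for the default OpenAILMM; dog.jpg is a placeholder path):

from vision_agent.agent import VisionAgent

agent = VisionAgent(verbosity=1)
resp = agent("Hello")  # returns the conversation so far, not generated code
resp.append(
    {
        "role": "user",
        "content": "Can you write a function that counts dogs?",
        "media": ["dog.jpg"],
    }
)
resp = agent(resp)  # the agent may execute meta-tool code before replying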