vision-agent 0.2.161__py3-none-any.whl → 0.2.163__py3-none-any.whl

@@ -0,0 +1,583 @@
+ import copy
+ import logging
+ from json import JSONDecodeError
+ from pathlib import Path
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
+
+ from pydantic import BaseModel
+ from tabulate import tabulate
+
+ import vision_agent.tools as T
+ from vision_agent.agent import Agent
+ from vision_agent.agent.agent_utils import (
+     _MAX_TABULATE_COL_WIDTH,
+     DefaultImports,
+     extract_code,
+     extract_json,
+     format_memory,
+     format_plans,
+     print_code,
+ )
+ from vision_agent.agent.vision_agent_planner_prompts import (
+     PICK_PLAN,
+     PLAN,
+     PREVIOUS_FAILED,
+     TEST_PLANS,
+     USER_REQ,
+ )
+ from vision_agent.lmm import (
+     LMM,
+     AnthropicLMM,
+     AzureOpenAILMM,
+     Message,
+     OllamaLMM,
+     OpenAILMM,
+ )
+ from vision_agent.utils.execute import (
+     CodeInterpreter,
+     CodeInterpreterFactory,
+     Execution,
+ )
+ from vision_agent.utils.sim import AzureSim, OllamaSim, Sim
+
+ _LOGGER = logging.getLogger(__name__)
+
+
+ class PlanContext(BaseModel):
+     plans: Dict[str, Dict[str, Union[str, List[str]]]]
+     best_plan: str
+     plan_thoughts: str
+     tool_output: str
+     tool_doc: str
+     test_results: Optional[Execution]
+
+
+ def retrieve_tools(
+     plans: Dict[str, Dict[str, Any]],
+     tool_recommender: Sim,
+     log_progress: Callable[[Dict[str, Any]], None],
+     verbosity: int = 0,
+ ) -> Dict[str, str]:
+     log_progress(
+         {
+             "type": "log",
+             "log_content": ("Retrieving tools for each plan"),
+             "status": "started",
+         }
+     )
+     tool_info = []
+     tool_desc = []
+     tool_lists: Dict[str, List[Dict[str, str]]] = {}
+     for k, plan in plans.items():
+         tool_lists[k] = []
+         for task in plan["instructions"]:
+             tools = tool_recommender.top_k(task, k=2, thresh=0.3)
+             tool_info.extend([e["doc"] for e in tools])
+             tool_desc.extend([e["desc"] for e in tools])
+             tool_lists[k].extend(
+                 {"description": e["desc"], "documentation": e["doc"]} for e in tools
+             )
+
+     if verbosity == 2:
+         tool_desc_str = "\n".join(set(tool_desc))
+         _LOGGER.info(f"Tools Description:\n{tool_desc_str}")
+
+     tool_lists_unique = {}
+     for k in tool_lists:
+         tool_lists_unique[k] = "\n\n".join(
+             set(e["documentation"] for e in tool_lists[k])
+         )
+     all_tools = "\n\n".join(set(tool_info))
+     tool_lists_unique["all"] = all_tools
+     return tool_lists_unique
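+
+ # Illustrative note: retrieve_tools only assumes that tool_recommender.top_k
+ # returns dicts carrying "desc" and "doc" keys, roughly like the hypothetical
+ # result below (tool names and docstrings are placeholders, not real tools):
+ #
+ #     tool_recommender.top_k("detect dogs in the image", k=2, thresh=0.3)
+ #     # -> [{"desc": "'detect_objects' finds objects...", "doc": "def detect_objects(...): ..."},
+ #     #     {"desc": "'classify_image' labels an image...", "doc": "def classify_image(...): ..."}]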
+
+
+ def _check_plan_format(plan: Dict[str, Any]) -> bool:
+     if not isinstance(plan, dict):
+         return False
+
+     for k in plan:
+         if "thoughts" not in plan[k] or "instructions" not in plan[k]:
+             return False
+         if not isinstance(plan[k]["instructions"], list):
+             return False
+     return True
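+
+ # For reference, a plans dict that passes the check above has this shape; the
+ # plan name, thoughts, and instruction strings here are purely illustrative:
+ #
+ #     {
+ #         "plan1": {
+ #             "thoughts": "Use an object detector to find the dogs.",
+ #             "instructions": ["Load the image", "Run a detector", "Count boxes"],
+ #         }
+ #     }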
+
+
+ def write_plans(
+     chat: List[Message], tool_desc: str, working_memory: str, model: LMM
+ ) -> Dict[str, Any]:
+     chat = copy.deepcopy(chat)
+     if chat[-1]["role"] != "user":
+         raise ValueError("Last message in chat must be from user")
+
+     user_request = chat[-1]["content"]
+     context = USER_REQ.format(user_request=user_request)
+     prompt = PLAN.format(
+         context=context,
+         tool_desc=tool_desc,
+         feedback=working_memory,
+     )
+     chat[-1]["content"] = prompt
+     plans = extract_json(model(chat, stream=False))  # type: ignore
+
+     count = 0
+     while not _check_plan_format(plans) and count < 3:
+         _LOGGER.info("Invalid plan format. Retrying.")
+         plans = extract_json(model(chat, stream=False))  # type: ignore
+         count += 1
+     # Raise only if the plans are still malformed after the retries, so a plan
+     # that becomes valid on the final retry is accepted.
+     if not _check_plan_format(plans):
+         raise ValueError("Failed to generate valid plans after 3 attempts.")
+     return plans
+
+
+ def write_and_exec_plan_tests(
+     plans: Dict[str, Any],
+     tool_info: str,
+     media: List[str],
+     model: LMM,
+     log_progress: Callable[[Dict[str, Any]], None],
+     code_interpreter: CodeInterpreter,
+     verbosity: int = 0,
+     max_retries: int = 3,
+ ) -> Tuple[str, Execution]:
+
+     plan_str = format_plans(plans)
+     prompt = TEST_PLANS.format(
+         docstring=tool_info, plans=plan_str, previous_attempts="", media=media
+     )
+
+     code = extract_code(model(prompt, stream=False))  # type: ignore
+     log_progress(
+         {
+             "type": "log",
+             "log_content": "Executing code to test plans",
+             "code": DefaultImports.prepend_imports(code),
+             "status": "running",
+         }
+     )
+     tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
+     # Because of the way we trace function calls, the trace information ends up
+     # in the results. We don't want to show this info to the LLM, so we don't
+     # include it in the tool_output_str.
+     tool_output_str = tool_output.text(include_results=False).strip()
+
+     if verbosity == 2:
+         print_code("Initial code and tests:", code)
+         _LOGGER.info(f"Initial code execution result:\n{tool_output_str}")
+
+     log_progress(
+         {
+             "type": "log",
+             "log_content": (
+                 "Code execution succeeded"
+                 if tool_output.success
+                 else "Code execution failed"
+             ),
+             "code": DefaultImports.prepend_imports(code),
+             # "payload": tool_output.to_json(),
+             "status": "completed" if tool_output.success else "failed",
+         }
+     )
+
+     # Retry if the code fails or the tool output is empty.
+     count = 0
+     tool_output_str = tool_output.text(include_results=False).strip()
+     while (
+         not tool_output.success
+         or (len(tool_output.logs.stdout) == 0 and len(tool_output.logs.stderr) == 0)
+     ) and count < max_retries:
+         prompt = TEST_PLANS.format(
+             docstring=tool_info,
+             plans=plan_str,
+             previous_attempts=PREVIOUS_FAILED.format(
+                 code=code, error="\n".join(tool_output_str.splitlines()[-50:])
+             ),
+             media=media,
+         )
+         log_progress(
+             {
+                 "type": "log",
+                 "log_content": "Retrying code to test plans",
+                 "status": "running",
+                 "code": DefaultImports.prepend_imports(code),
+             }
+         )
+         code = extract_code(model(prompt, stream=False))  # type: ignore
+         tool_output = code_interpreter.exec_isolation(
+             DefaultImports.prepend_imports(code)
+         )
+         log_progress(
+             {
+                 "type": "log",
+                 "log_content": (
+                     "Code execution succeeded"
+                     if tool_output.success
+                     else "Code execution failed"
+                 ),
+                 "code": DefaultImports.prepend_imports(code),
+                 # "payload": tool_output.to_json(),
+                 "status": "completed" if tool_output.success else "failed",
+             }
+         )
+         tool_output_str = tool_output.text(include_results=False).strip()
+
+         if verbosity == 2:
+             print_code("Code and test after attempted fix:", code)
+             _LOGGER.info(f"Code execution result after attempt {count + 1}")
+             _LOGGER.info(f"{tool_output_str}")
+
+         count += 1
+
+     return code, tool_output
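+
+ # Note: as used above, the Execution object exposes .success,
+ # .logs.stdout / .logs.stderr, and .text(include_results=...); the returned
+ # string is the raw test code the model wrote to exercise each candidate plan.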
+
+
+ def write_plan_thoughts(
+     chat: List[Message],
+     plans: Dict[str, Any],
+     tool_output_str: str,
+     model: LMM,
+     max_retries: int = 3,
+ ) -> Dict[str, str]:
+     user_req = chat[-1]["content"]
+     context = USER_REQ.format(user_request=user_req)
+     # Because the tool picker model gets the image as well, we have to be
+     # careful with how much text we send it, so we truncate the tool output to
+     # 20,000 characters.
+     prompt = PICK_PLAN.format(
+         context=context,
+         plans=format_plans(plans),
+         tool_output=tool_output_str[:20_000],
+     )
+     chat[-1]["content"] = prompt
+     count = 0
+
+     plan_thoughts = None
+     while plan_thoughts is None and count < max_retries:
+         try:
+             plan_thoughts = extract_json(model(chat, stream=False))  # type: ignore
+         except JSONDecodeError as e:
+             _LOGGER.exception(
+                 f"Error while extracting JSON when picking the best plan: {str(e)}"
+             )
+         count += 1
+
+     if (
+         plan_thoughts is None
+         or "best_plan" not in plan_thoughts
+         or ("best_plan" in plan_thoughts and plan_thoughts["best_plan"] not in plans)
+     ):
+         _LOGGER.info(f"Failed to pick best plan. Using the first plan. {plan_thoughts}")
+         plan_thoughts = {"best_plan": list(plans.keys())[0]}
+
+     if "thoughts" not in plan_thoughts:
+         plan_thoughts["thoughts"] = ""
+     return plan_thoughts
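+
+ # A well-formed model response parses into JSON of roughly this shape (values
+ # illustrative):
+ #
+ #     {"best_plan": "plan1", "thoughts": "plan1's output matched the request"}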
+
+
+ def pick_plan(
+     chat: List[Message],
+     plans: Dict[str, Any],
+     tool_info: str,
+     model: LMM,
+     code_interpreter: CodeInterpreter,
+     media: List[str],
+     log_progress: Callable[[Dict[str, Any]], None],
+     verbosity: int = 0,
+     max_retries: int = 3,
+ ) -> Tuple[Dict[str, str], str, Execution]:
+     log_progress(
+         {
+             "type": "log",
+             "log_content": "Generating code to pick the best plan",
+             "status": "started",
+         }
+     )
+
+     chat = copy.deepcopy(chat)
+     if chat[-1]["role"] != "user":
+         raise ValueError("Last chat message must be from the user.")
+
+     code, tool_output = write_and_exec_plan_tests(
+         plans,
+         tool_info,
+         media,
+         model,
+         log_progress,
+         code_interpreter,
+         verbosity,
+         max_retries,
+     )
+
+     if verbosity >= 1:
+         print_code("Final code:", code)
+
+     plan_thoughts = write_plan_thoughts(
+         chat,
+         plans,
+         tool_output.text(include_results=False).strip(),
+         model,
+         max_retries,
+     )
+
+     if verbosity >= 1:
+         _LOGGER.info(f"Best plan:\n{plan_thoughts}")
+     log_progress(
+         {
+             "type": "log",
+             "log_content": "Picked best plan",
+             "status": "completed",
+             "payload": plans[plan_thoughts["best_plan"]],
+         }
+     )
+     return plan_thoughts, code, tool_output
+
+
+ class VisionAgentPlanner(Agent):
+     def __init__(
+         self,
+         planner: Optional[LMM] = None,
+         tool_recommender: Optional[Sim] = None,
+         verbosity: int = 0,
+         report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
+         code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
+     ) -> None:
+         self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner
+         self.verbosity = verbosity
+         if self.verbosity > 0:
+             _LOGGER.setLevel(logging.INFO)
+
+         self.tool_recommender = (
+             Sim(T.TOOLS_DF, sim_key="desc")
+             if tool_recommender is None
+             else tool_recommender
+         )
+         self.report_progress_callback = report_progress_callback
+         self.code_interpreter = code_interpreter
+
+     def __call__(
+         self, input: Union[str, List[Message]], media: Optional[Union[str, Path]] = None
+     ) -> str:
+         if isinstance(input, str):
+             input = [{"role": "user", "content": input}]
+             if media is not None:
+                 input[0]["media"] = [media]
+         planning_context = self.generate_plan(input)
+         return str(planning_context.plans[planning_context.best_plan])
+
+     def generate_plan(
+         self,
+         chat: List[Message],
+         test_multi_plan: bool = True,
+         custom_tool_names: Optional[List[str]] = None,
+         code_interpreter: Optional[CodeInterpreter] = None,
+     ) -> PlanContext:
+         if not chat:
+             raise ValueError("Chat cannot be empty")
+
+         code_interpreter = (
+             code_interpreter
+             if code_interpreter is not None
+             else (
+                 self.code_interpreter
+                 if not isinstance(self.code_interpreter, str)
+                 else CodeInterpreterFactory.new_instance(self.code_interpreter)
+             )
+         )
+         code_interpreter = cast(CodeInterpreter, code_interpreter)
+         with code_interpreter:
+             chat = copy.deepcopy(chat)
+             media_list = []
+             for chat_i in chat:
+                 if "media" in chat_i:
+                     for media in chat_i["media"]:
+                         media = (
+                             media
+                             if type(media) is str
+                             and media.startswith(("http", "https"))
+                             else code_interpreter.upload_file(cast(str, media))
+                         )
+                         chat_i["content"] += f" Media name {media}"  # type: ignore
+                         media_list.append(str(media))
+
+             int_chat = cast(
+                 List[Message],
+                 [
+                     (
+                         {
+                             "role": c["role"],
+                             "content": c["content"],
+                             "media": c["media"],
+                         }
+                         if "media" in c
+                         else {"role": c["role"], "content": c["content"]}
+                     )
+                     for c in chat
+                 ],
+             )
+
+             working_memory: List[Dict[str, str]] = []
+
+             plans = write_plans(
+                 chat,
+                 T.get_tool_descriptions_by_names(
+                     custom_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS  # type: ignore
+                 ),
+                 format_memory(working_memory),
+                 self.planner,
+             )
+             if self.verbosity >= 1:
+                 for plan in plans:
+                     plan_fixed = [
+                         {"instructions": e} for e in plans[plan]["instructions"]
+                     ]
+                     _LOGGER.info(
+                         f"\n{tabulate(tabular_data=plan_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
+                     )
+
+             tool_docs = retrieve_tools(
+                 plans,
+                 self.tool_recommender,
+                 self.log_progress,
+                 self.verbosity,
+             )
+             if test_multi_plan:
+                 plan_thoughts, code, tool_output = pick_plan(
+                     int_chat,
+                     plans,
+                     tool_docs["all"],
+                     self.planner,
+                     code_interpreter,
+                     media_list,
+                     self.log_progress,
+                     self.verbosity,
+                 )
+                 best_plan = plan_thoughts["best_plan"]
+                 plan_thoughts_str = plan_thoughts["thoughts"]
+                 tool_output_str = (
+                     "```python\n"
+                     + code
+                     + "\n```\n"
+                     + tool_output.text(include_results=False).strip()
+                 )
+             else:
+                 best_plan = list(plans.keys())[0]
+                 tool_output_str = ""
+                 plan_thoughts_str = ""
+                 tool_output = None
+
+             if best_plan in plans and best_plan in tool_docs:
+                 tool_doc = tool_docs[best_plan]
+             else:
+                 if self.verbosity >= 1:
+                     _LOGGER.warning(
+                         f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
+                     )
+                 k = list(plans.keys())[0]
+                 best_plan = k
+                 tool_doc = tool_docs[k]
+
+             return PlanContext(
+                 plans=plans,
+                 best_plan=best_plan,
+                 plan_thoughts=plan_thoughts_str,
+                 tool_output=tool_output_str,
+                 test_results=tool_output,
+                 tool_doc=tool_doc,
+             )
+
+     def log_progress(self, log: Dict[str, Any]) -> None:
+         if self.report_progress_callback is not None:
+             self.report_progress_callback(log)
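+
+ # Minimal usage sketch (assumes credentials for the default AnthropicLMM
+ # planner are configured; the request and file name are illustrative):
+ #
+ #     planner = VisionAgentPlanner(verbosity=1)
+ #     ctx = planner.generate_plan(
+ #         [{"role": "user", "content": "Count the dogs", "media": ["dogs.jpg"]}]
+ #     )
+ #     print(ctx.best_plan, ctx.plans[ctx.best_plan])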
+
+
+ class AnthropicVisionAgentPlanner(VisionAgentPlanner):
+     def __init__(
+         self,
+         planner: Optional[LMM] = None,
+         tool_recommender: Optional[Sim] = None,
+         verbosity: int = 0,
+         report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
+         code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
+     ) -> None:
+         super().__init__(
+             planner=AnthropicLMM(temperature=0.0) if planner is None else planner,
+             tool_recommender=tool_recommender,
+             verbosity=verbosity,
+             report_progress_callback=report_progress_callback,
+             code_interpreter=code_interpreter,
+         )
+
+
+ class OpenAIVisionAgentPlanner(VisionAgentPlanner):
+     def __init__(
+         self,
+         planner: Optional[LMM] = None,
+         tool_recommender: Optional[Sim] = None,
+         verbosity: int = 0,
+         report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
+         code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
+     ) -> None:
+         super().__init__(
+             planner=(
+                 OpenAILMM(temperature=0.0, json_mode=True)
+                 if planner is None
+                 else planner
+             ),
+             tool_recommender=tool_recommender,
+             verbosity=verbosity,
+             report_progress_callback=report_progress_callback,
+             code_interpreter=code_interpreter,
+         )
+
+
+ class OllamaVisionAgentPlanner(VisionAgentPlanner):
+     def __init__(
+         self,
+         planner: Optional[LMM] = None,
+         tool_recommender: Optional[Sim] = None,
+         verbosity: int = 0,
+         report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
+         code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
+     ) -> None:
+         super().__init__(
+             planner=(
+                 OllamaLMM(model_name="llama3.1", temperature=0.0)
+                 if planner is None
+                 else planner
+             ),
+             tool_recommender=(
+                 OllamaSim(T.TOOLS_DF, sim_key="desc")
+                 if tool_recommender is None
+                 else tool_recommender
+             ),
+             verbosity=verbosity,
+             report_progress_callback=report_progress_callback,
+             code_interpreter=code_interpreter,
+         )
+
+
+ class AzureVisionAgentPlanner(VisionAgentPlanner):
+     def __init__(
+         self,
+         planner: Optional[LMM] = None,
+         tool_recommender: Optional[Sim] = None,
+         verbosity: int = 0,
+         report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
+         code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
+     ) -> None:
+         super().__init__(
+             planner=(
+                 AzureOpenAILMM(temperature=0.0, json_mode=True)
+                 if planner is None
+                 else planner
+             ),
+             tool_recommender=(
+                 AzureSim(T.TOOLS_DF, sim_key="desc")
+                 if tool_recommender is None
+                 else tool_recommender
+             ),
+             verbosity=verbosity,
+             report_progress_callback=report_progress_callback,
+             code_interpreter=code_interpreter,
+         )
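+
+ # The provider-specific subclasses above only swap the default planner LMM and
+ # tool recommender. A hedged local example (assumes a running Ollama server
+ # with the llama3.1 model pulled; the prompt and file name are illustrative):
+ #
+ #     planner = OllamaVisionAgentPlanner(verbosity=1)
+ #     plan = planner("Detect the defects on the metal sheet", media="sheet.png")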