vision-agent 0.2.161__py3-none-any.whl → 0.2.163__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,583 @@
1
+ import copy
2
+ import logging
3
+ from json import JSONDecodeError
4
+ from pathlib import Path
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
6
+
7
+ from pydantic import BaseModel
8
+ from tabulate import tabulate
9
+
10
+ import vision_agent.tools as T
11
+ from vision_agent.agent import Agent
12
+ from vision_agent.agent.agent_utils import (
13
+ _MAX_TABULATE_COL_WIDTH,
14
+ DefaultImports,
15
+ extract_code,
16
+ extract_json,
17
+ format_memory,
18
+ format_plans,
19
+ print_code,
20
+ )
21
+ from vision_agent.agent.vision_agent_planner_prompts import (
22
+ PICK_PLAN,
23
+ PLAN,
24
+ PREVIOUS_FAILED,
25
+ TEST_PLANS,
26
+ USER_REQ,
27
+ )
28
+ from vision_agent.lmm import (
29
+ LMM,
30
+ AnthropicLMM,
31
+ AzureOpenAILMM,
32
+ Message,
33
+ OllamaLMM,
34
+ OpenAILMM,
35
+ )
36
+ from vision_agent.utils.execute import (
37
+ CodeInterpreter,
38
+ CodeInterpreterFactory,
39
+ Execution,
40
+ )
41
+ from vision_agent.utils.sim import AzureSim, OllamaSim, Sim
42
+
43
+ _LOGGER = logging.getLogger(__name__)
44
+
45
+
46
class PlanContext(BaseModel):
    """Outcome of a planning run, as returned by ``VisionAgentPlanner.generate_plan``."""

    # All candidate plans keyed by plan name; each plan maps field names (the planner
    # writes "thoughts" and "instructions") to a string or a list of strings.
    plans: Dict[str, Dict[str, Union[str, List[str]]]]
    # Name of the selected plan; a key of ``plans``.
    best_plan: str
    # The model's reasoning for the selection; "" when plans were not tested.
    plan_thoughts: str
    # Plan-testing code (as a fenced python block) followed by its text output;
    # "" when plans were not tested.
    tool_output: str
    # Tool documentation retrieved for the selected plan.
    tool_doc: str
    # Execution result of the plan-testing code; None when plans were not tested.
    test_results: Optional[Execution]
53
+
54
+
55
def retrieve_tools(
    plans: Dict[str, Dict[str, Any]],
    tool_recommender: Sim,
    log_progress: Callable[[Dict[str, Any]], None],
    verbosity: int = 0,
    k: int = 2,
    thresh: float = 0.3,
) -> Dict[str, str]:
    """Retrieve the tool documentation relevant to each plan's instructions.

    For every instruction of every plan, ``tool_recommender.top_k`` is queried and the
    ``"doc"`` strings of the returned tools are collected.

    Args:
        plans: Mapping of plan name to plan; each plan must have an ``"instructions"``
            iterable of task strings.
        tool_recommender: Similarity index exposing ``top_k(query, k=..., thresh=...)``
            that returns a list of dicts with ``"doc"`` and ``"desc"`` keys.
        log_progress: Callback that receives progress-log dicts.
        verbosity: When 2, the unique tool descriptions are logged.
        k: Maximum number of tools retrieved per instruction.
        thresh: Similarity threshold forwarded to ``top_k``.

    Returns:
        A mapping from each plan name to that plan's de-duplicated tool docs joined by
        blank lines, plus an ``"all"`` entry with the de-duplicated docs of every plan.
        Duplicates are dropped keeping first-seen order, so the text fed into prompts
        is identical between runs (``set`` ordering of strings depends on the
        per-process hash seed).
    """
    log_progress(
        {
            "type": "log",
            "log_content": "Retrieving tools for each plan",
            "status": "started",
        }
    )
    all_docs: List[str] = []
    all_descs: List[str] = []
    docs_per_plan: Dict[str, List[str]] = {}
    for plan_name, plan in plans.items():
        plan_docs: List[str] = []
        docs_per_plan[plan_name] = plan_docs
        for task in plan["instructions"]:
            for tool in tool_recommender.top_k(task, k=k, thresh=thresh):
                all_docs.append(tool["doc"])
                all_descs.append(tool["desc"])
                plan_docs.append(tool["doc"])

    if verbosity == 2:
        tool_desc_str = "\n".join(dict.fromkeys(all_descs))
        _LOGGER.info(f"Tools Description:\n{tool_desc_str}")

    # dict.fromkeys de-duplicates while preserving insertion order.
    tool_docs = {
        plan_name: "\n\n".join(dict.fromkeys(docs))
        for plan_name, docs in docs_per_plan.items()
    }
    tool_docs["all"] = "\n\n".join(dict.fromkeys(all_docs))
    return tool_docs
93
+
94
+
95
+ def _check_plan_format(plan: Dict[str, Any]) -> bool:
96
+ if not isinstance(plan, dict):
97
+ return False
98
+
99
+ for k in plan:
100
+ if "thoughts" not in plan[k] or "instructions" not in plan[k]:
101
+ return False
102
+ if not isinstance(plan[k]["instructions"], list):
103
+ return False
104
+ return True
105
+
106
+
107
def write_plans(
    chat: List[Message],
    tool_desc: str,
    working_memory: str,
    model: LMM,
    max_retries: int = 3,
) -> Dict[str, Any]:
    """Ask ``model`` to propose candidate plans for the latest user request.

    Args:
        chat: Conversation so far; its last message must come from the user. The list
            is deep-copied, so the caller's messages are not modified.
        tool_desc: Tool descriptions inserted into the ``PLAN`` prompt.
        working_memory: Feedback text inserted into the ``PLAN`` prompt.
        model: LMM used to generate the plans.
        max_retries: Number of additional requests made when a response is not valid
            JSON or fails ``_check_plan_format``.

    Returns:
        The parsed plans: plan names mapped to dicts with ``"thoughts"`` and a list of
        ``"instructions"``.

    Raises:
        ValueError: If the last message is not from the user, or if no valid plans
            were produced by the initial request plus ``max_retries`` retries.
    """
    chat = copy.deepcopy(chat)
    if chat[-1]["role"] != "user":
        raise ValueError("Last message in chat must be from user")

    user_request = chat[-1]["content"]
    context = USER_REQ.format(user_request=user_request)
    chat[-1]["content"] = PLAN.format(
        context=context,
        tool_desc=tool_desc,
        feedback=working_memory,
    )

    # One initial request plus up to ``max_retries`` retries. Validity is checked
    # after every attempt, so a valid answer on the final retry is returned rather
    # than discarded.
    for attempt in range(max_retries + 1):
        if attempt > 0:
            _LOGGER.info("Invalid plan format. Retrying.")
        try:
            plans = extract_json(model(chat, stream=False))  # type: ignore
        except JSONDecodeError as e:
            # Unparseable output counts as an invalid attempt instead of aborting.
            _LOGGER.info(f"Failed to parse plans as JSON: {e}")
            continue
        if _check_plan_format(plans):
            return plans
    raise ValueError(f"Failed to generate valid plans after {max_retries} attempts.")
132
+
133
+
134
def write_and_exec_plan_tests(
    plans: Dict[str, Any],
    tool_info: str,
    media: List[str],
    model: LMM,
    log_progress: Callable[[Dict[str, Any]], None],
    code_interpreter: CodeInterpreter,
    verbosity: int = 0,
    max_retries: int = 3,
) -> Tuple[str, Execution]:
    """Have ``model`` write code that exercises the plans' tools, run it, and retry.

    The code is generated from the ``TEST_PLANS`` prompt and executed in isolation in
    ``code_interpreter`` with the default imports prepended. The model is re-prompted
    when execution fails or produces neither stdout nor stderr. The new prompt includes
    the previous code and the last 50 lines of its output. This repeats at most
    ``max_retries`` times.

    Returns:
        The last generated code (without the prepended imports) and its ``Execution``.
    """
    formatted_plans = format_plans(plans)

    def build_prompt(previous_attempts: str) -> str:
        return TEST_PLANS.format(
            docstring=tool_info,
            plans=formatted_plans,
            previous_attempts=previous_attempts,
            media=media,
        )

    def output_text(result: Execution) -> str:
        # Function-call tracing ends up in the execution results, so results are
        # excluded from the text shown to the LLM.
        return result.text(include_results=False).strip()

    def report_result(full_code: str, result: Execution) -> None:
        succeeded = result.success
        log_progress(
            {
                "type": "log",
                "log_content": (
                    "Code execution succeeded" if succeeded else "Code execution failed"
                ),
                "code": full_code,
                "status": "completed" if succeeded else "failed",
            }
        )

    def needs_retry(result: Execution) -> bool:
        silent = len(result.logs.stdout) == 0 and len(result.logs.stderr) == 0
        return not result.success or silent

    code = extract_code(model(build_prompt(""), stream=False))  # type: ignore
    full_code = DefaultImports.prepend_imports(code)
    log_progress(
        {
            "type": "log",
            "log_content": "Executing code to test plans",
            "code": full_code,
            "status": "running",
        }
    )
    tool_output = code_interpreter.exec_isolation(full_code)
    tool_output_str = output_text(tool_output)

    if verbosity == 2:
        print_code("Initial code and tests:", code)
        _LOGGER.info(f"Initial code execution result:\n{tool_output_str}")

    report_result(full_code, tool_output)

    for attempt in range(max_retries):
        if not needs_retry(tool_output):
            break
        last_error = "\n".join(tool_output_str.splitlines()[-50:])
        previous_attempts = PREVIOUS_FAILED.format(code=code, error=last_error)
        # The "retrying" event carries the code that failed.
        log_progress(
            {
                "type": "log",
                "log_content": "Retrying code to test plans",
                "status": "running",
                "code": full_code,
            }
        )
        code = extract_code(model(build_prompt(previous_attempts), stream=False))  # type: ignore
        full_code = DefaultImports.prepend_imports(code)
        tool_output = code_interpreter.exec_isolation(full_code)
        report_result(full_code, tool_output)
        tool_output_str = output_text(tool_output)

        if verbosity == 2:
            print_code("Code and test after attempted fix:", code)
            _LOGGER.info(f"Code execution result after attempt {attempt + 1}")
            _LOGGER.info(f"{tool_output_str}")

    return code, tool_output
233
+
234
+
235
def write_plan_thoughts(
    chat: List[Message],
    plans: Dict[str, Any],
    tool_output_str: str,
    model: LMM,
    max_retries: int = 3,
) -> Dict[str, str]:
    """Ask ``model`` to choose the best plan given the output of the plan tests.

    Args:
        chat: Conversation whose last message holds the user request. It is
            deep-copied before the prompt is written into it, so the caller's
            messages are left untouched.
        plans: Candidate plans keyed by name.
        tool_output_str: Text output of the plan-testing code; truncated to 20,000
            characters in the prompt.
        model: LMM used to choose the plan.
        max_retries: Maximum number of model calls made while the response cannot be
            parsed into a JSON object.

    Returns:
        A dict whose ``"best_plan"`` names a key of ``plans`` and whose
        ``"thoughts"`` holds the model's reasoning. The first plan is used when the
        model's choice is missing or invalid, and ``"thoughts"`` is "" when absent.
    """
    chat = copy.deepcopy(chat)
    user_req = chat[-1]["content"]
    context = USER_REQ.format(user_request=user_req)
    # because the tool picker model gets the image as well, we have to be careful with
    # how much text we send it, so we truncate the tool output to 20,000 characters
    chat[-1]["content"] = PICK_PLAN.format(
        context=context,
        plans=format_plans(plans),
        tool_output=tool_output_str[:20_000],
    )

    plan_thoughts: Any = None
    for _ in range(max_retries):
        try:
            plan_thoughts = extract_json(model(chat, stream=False))  # type: ignore
        except JSONDecodeError as e:
            _LOGGER.exception(
                f"Error while extracting JSON during picking best plan {str(e)}"
            )
            plan_thoughts = None
        if isinstance(plan_thoughts, dict):
            break

    best_plan = (
        plan_thoughts.get("best_plan") if isinstance(plan_thoughts, dict) else None
    )
    # The isinstance check also guards against unhashable values, which would make
    # the membership test raise TypeError.
    if not isinstance(best_plan, str) or best_plan not in plans:
        _LOGGER.info(f"Failed to pick best plan. Using the first plan. {plan_thoughts}")
        plan_thoughts = {"best_plan": list(plans.keys())[0]}

    plan_thoughts.setdefault("thoughts", "")
    return plan_thoughts
276
+
277
+
278
def pick_plan(
    chat: List[Message],
    plans: Dict[str, Any],
    tool_info: str,
    model: LMM,
    code_interpreter: CodeInterpreter,
    media: List[str],
    log_progress: Callable[[Dict[str, Any]], None],
    verbosity: int = 0,
    max_retries: int = 3,
) -> Tuple[Dict[str, str], str, Execution]:
    """Test the candidate plans with generated code, then let ``model`` choose one.

    Returns:
        ``(plan_thoughts, code, tool_output)``: the model's choice (with
        ``"best_plan"`` and ``"thoughts"`` keys), the final test code, and that
        code's ``Execution``.

    Raises:
        ValueError: If the last message of ``chat`` is not from the user.
    """
    log_progress(
        {
            "type": "log",
            "log_content": "Generating code to pick the best plan",
            "status": "started",
        }
    )

    chat_copy = copy.deepcopy(chat)
    if chat_copy[-1]["role"] != "user":
        raise ValueError("Last chat message must be from the user.")

    code, tool_output = write_and_exec_plan_tests(
        plans=plans,
        tool_info=tool_info,
        media=media,
        model=model,
        log_progress=log_progress,
        code_interpreter=code_interpreter,
        verbosity=verbosity,
        max_retries=max_retries,
    )

    if verbosity >= 1:
        print_code("Final code:", code)

    test_output_text = tool_output.text(include_results=False).strip()
    plan_thoughts = write_plan_thoughts(
        chat=chat_copy,
        plans=plans,
        tool_output_str=test_output_text,
        model=model,
        max_retries=max_retries,
    )

    if verbosity >= 1:
        _LOGGER.info(f"Best plan:\n{plan_thoughts}")

    chosen_plan = plans[plan_thoughts["best_plan"]]
    log_progress(
        {
            "type": "log",
            "log_content": "Picked best plan",
            "status": "completed",
            "payload": chosen_plan,
        }
    )
    return plan_thoughts, code, tool_output
334
+
335
+
336
class VisionAgentPlanner(Agent):
    """Agent that drafts candidate plans for a vision task and selects the best one.

    ``planner`` writes the plans and ``tool_recommender`` retrieves tool documentation
    for them. Optionally, each plan's tools are then exercised by generated code run
    in a code interpreter before a plan is chosen.
    """

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        """Create a planner.

        Args:
            planner: LMM that writes and picks plans; defaults to
                ``AnthropicLMM(temperature=0.0)``.
            tool_recommender: Similarity index over tool descriptions; defaults to
                ``Sim(T.TOOLS_DF, sim_key="desc")``.
            verbosity: 0 is quiet; >= 1 logs plans and the chosen plan (and sets this
                module's logger to INFO); 2 also logs tool descriptions and every
                test-code attempt.
            report_progress_callback: Optional callback receiving progress-log dicts.
            code_interpreter: A ``CodeInterpreter``, a runtime name, or ``None``.
                For a runtime name or ``None``, a new interpreter is created with
                ``CodeInterpreterFactory.new_instance`` on each ``generate_plan``
                call.
        """
        self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner
        self.verbosity = verbosity
        if self.verbosity > 0:
            _LOGGER.setLevel(logging.INFO)

        self.tool_recommender = (
            Sim(T.TOOLS_DF, sim_key="desc")
            if tool_recommender is None
            else tool_recommender
        )
        self.report_progress_callback = report_progress_callback
        self.code_interpreter = code_interpreter

    def __call__(
        self, input: Union[str, List[Message]], media: Optional[Union[str, Path]] = None
    ) -> str:
        """Plan for ``input`` and return the selected plan converted to a string.

        A plain-string ``input`` is wrapped into one user message, with ``media``
        attached when given. ``media`` is not used when ``input`` is already a list of
        messages.
        """
        if isinstance(input, str):
            input = [{"role": "user", "content": input}]
            if media is not None:
                input[0]["media"] = [media]
        planning_context = self.generate_plan(input)
        return str(planning_context.plans[planning_context.best_plan])

    def _resolve_code_interpreter(
        self, code_interpreter: Optional[CodeInterpreter]
    ) -> CodeInterpreter:
        """Return the interpreter to use for one ``generate_plan`` call.

        An explicitly passed interpreter wins. Otherwise the configured one is used.
        A runtime name, or ``None``, goes through ``CodeInterpreterFactory``, because
        using ``None`` directly as the ``with`` target in ``generate_plan`` would
        fail.
        """
        if code_interpreter is not None:
            return code_interpreter
        if self.code_interpreter is None or isinstance(self.code_interpreter, str):
            return CodeInterpreterFactory.new_instance(self.code_interpreter)
        return self.code_interpreter

    def generate_plan(
        self,
        chat: List[Message],
        test_multi_plan: bool = True,
        custom_tool_names: Optional[List[str]] = None,
        code_interpreter: Optional[CodeInterpreter] = None,
    ) -> PlanContext:
        """Generate candidate plans for ``chat`` and select the best one.

        Args:
            chat: Conversation; messages may carry a ``"media"`` list. Each media
                name is appended to its message's content. Media that are not
                strings starting with "http" are first uploaded to the code
                interpreter.
            test_multi_plan: If True, the plans' tools are exercised by generated
                test code and the model picks a plan from the results. If False, the
                first plan is used and no test code is run.
            custom_tool_names: Optional tool names forwarded to
                ``T.get_tool_descriptions_by_names`` to pick the tool descriptions
                shown to the planner.
            code_interpreter: Interpreter for this call, overriding the configured
                one; it is entered as a context manager for the whole call.

        Returns:
            A ``PlanContext`` holding:
                - all plans and the selected plan name,
                - the model's reasoning,
                - the test code and its output,
                - the selected plan's tool docs,
                - the test ``Execution`` (``None`` when ``test_multi_plan`` is False).

        Raises:
            ValueError: If ``chat`` is empty.
        """
        if not chat:
            raise ValueError("Chat cannot be empty")

        code_interpreter = self._resolve_code_interpreter(code_interpreter)
        with code_interpreter:
            chat = copy.deepcopy(chat)
            media_list: List[str] = []
            for chat_i in chat:
                if "media" in chat_i:
                    for media in chat_i["media"]:
                        # URLs are kept as-is; other media are uploaded and referred
                        # to by the value ``upload_file`` returns.
                        media = (
                            media
                            if type(media) is str
                            and media.startswith(("http", "https"))
                            else code_interpreter.upload_file(cast(str, media))
                        )
                        chat_i["content"] += f" Media name {media}"  # type: ignore
                        media_list.append(str(media))

            # Role/content/media-only view of the (media-annotated) chat for pick_plan.
            int_chat = cast(
                List[Message],
                [
                    (
                        {
                            "role": c["role"],
                            "content": c["content"],
                            "media": c["media"],
                        }
                        if "media" in c
                        else {"role": c["role"], "content": c["content"]}
                    )
                    for c in chat
                ],
            )

            working_memory: List[Dict[str, str]] = []

            plans = write_plans(
                chat,
                T.get_tool_descriptions_by_names(
                    custom_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS  # type: ignore
                ),
                format_memory(working_memory),
                self.planner,
            )
            if self.verbosity >= 1:
                for plan in plans:
                    plan_fixed = [
                        {"instructions": e} for e in plans[plan]["instructions"]
                    ]
                    _LOGGER.info(
                        f"\n{tabulate(tabular_data=plan_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
                    )

            tool_docs = retrieve_tools(
                plans,
                self.tool_recommender,
                self.log_progress,
                self.verbosity,
            )
            if test_multi_plan:
                plan_thoughts, code, tool_output = pick_plan(
                    int_chat,
                    plans,
                    tool_docs["all"],
                    self.planner,
                    code_interpreter,
                    media_list,
                    self.log_progress,
                    self.verbosity,
                )
                best_plan = plan_thoughts["best_plan"]
                plan_thoughts_str = plan_thoughts["thoughts"]
                tool_output_str = (
                    "```python\n"
                    + code
                    + "\n```\n"
                    + tool_output.text(include_results=False).strip()
                )
            else:
                best_plan = list(plans.keys())[0]
                tool_output_str = ""
                plan_thoughts_str = ""
                tool_output = None

            if best_plan in plans and best_plan in tool_docs:
                tool_doc = tool_docs[best_plan]
            else:
                if self.verbosity >= 1:
                    _LOGGER.warning(
                        f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
                    )
                k = list(plans.keys())[0]
                best_plan = k
                tool_doc = tool_docs[k]

            return PlanContext(
                plans=plans,
                best_plan=best_plan,
                plan_thoughts=plan_thoughts_str,
                tool_output=tool_output_str,
                test_results=tool_output,
                tool_doc=tool_doc,
            )

    def log_progress(self, log: Dict[str, Any]) -> None:
        """Forward ``log`` to the progress callback, if one was configured."""
        if self.report_progress_callback is not None:
            self.report_progress_callback(log)
492
+
493
+
494
class AnthropicVisionAgentPlanner(VisionAgentPlanner):
    """``VisionAgentPlanner`` whose default planner is ``AnthropicLMM(temperature=0.0)``."""

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        if planner is None:
            planner = AnthropicLMM(temperature=0.0)
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )
510
+
511
+
512
class OpenAIVisionAgentPlanner(VisionAgentPlanner):
    """``VisionAgentPlanner`` defaulting to a JSON-mode ``OpenAILMM`` planner."""

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        if planner is None:
            planner = OpenAILMM(temperature=0.0, json_mode=True)
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )
532
+
533
+
534
class OllamaVisionAgentPlanner(VisionAgentPlanner):
    """``VisionAgentPlanner`` defaulting to Ollama for both planning and tool retrieval.

    Defaults: ``OllamaLMM(model_name="llama3.1", temperature=0.0)`` as the planner and
    ``OllamaSim(T.TOOLS_DF, sim_key="desc")`` as the tool recommender.
    """

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        # Defaults are built in the same order as before: planner first, then recommender.
        if planner is None:
            planner = OllamaLMM(model_name="llama3.1", temperature=0.0)
        if tool_recommender is None:
            tool_recommender = OllamaSim(T.TOOLS_DF, sim_key="desc")
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )
558
+
559
+
560
class AzureVisionAgentPlanner(VisionAgentPlanner):
    """``VisionAgentPlanner`` defaulting to Azure OpenAI for planning and tool retrieval.

    Defaults: ``AzureOpenAILMM(temperature=0.0, json_mode=True)`` as the planner and
    ``AzureSim(T.TOOLS_DF, sim_key="desc")`` as the tool recommender.
    """

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        # Defaults are built in the same order as before: planner first, then recommender.
        if planner is None:
            planner = AzureOpenAILMM(temperature=0.0, json_mode=True)
        if tool_recommender is None:
            tool_recommender = AzureSim(T.TOOLS_DF, sim_key="desc")
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )