vision-agent 0.2.161__py3-none-any.whl → 0.2.162__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,553 @@
1
+ import copy
2
+ import logging
3
+ from json import JSONDecodeError
4
+ from pathlib import Path
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
6
+
7
+ from pydantic import BaseModel
8
+
9
+ import vision_agent.tools as T
10
+ from vision_agent.agent import Agent
11
+ from vision_agent.agent.agent_utils import (
12
+ DefaultImports,
13
+ extract_code,
14
+ extract_json,
15
+ format_memory,
16
+ format_plans,
17
+ print_code,
18
+ )
19
+ from vision_agent.agent.vision_agent_planner_prompts import (
20
+ PICK_PLAN,
21
+ PLAN,
22
+ PREVIOUS_FAILED,
23
+ TEST_PLANS,
24
+ USER_REQ,
25
+ )
26
+ from vision_agent.lmm import (
27
+ LMM,
28
+ AnthropicLMM,
29
+ AzureOpenAILMM,
30
+ Message,
31
+ OllamaLMM,
32
+ OpenAILMM,
33
+ )
34
+ from vision_agent.utils.execute import (
35
+ CodeInterpreter,
36
+ CodeInterpreterFactory,
37
+ Execution,
38
+ )
39
+ from vision_agent.utils.sim import AzureSim, OllamaSim, Sim
40
+
41
# Module-level logger; VisionAgentPlanner raises its level to INFO when verbosity > 0.
_LOGGER = logging.getLogger(name=__name__)
42
+
43
+
44
class PlanContext(BaseModel):
    # Result object returned by VisionAgentPlanner.generate_plan.

    # Candidate plans keyed by plan name, as produced by write_plans; each
    # plan's "instructions" entry is what retrieve_tools iterates over.
    plans: Dict[str, Dict[str, Union[str, List[str]]]]
    # Key into `plans` of the selected plan (generate_plan falls back to the
    # first key when the picked name is unknown).
    best_plan: str
    # The picker's "thoughts" text; "" when plans were not tested.
    plan_thoughts: str
    # Fenced ```python block of the plan-test code followed by its execution
    # text; "" when plans were not tested.
    tool_output: str
    # Tool documentation retrieved for the best plan.
    tool_doc: str
    # Execution result of the plan-test code; None when plans were not tested.
    test_results: Optional[Execution]
51
+
52
+
53
def retrieve_tools(
    plans: Dict[str, Dict[str, Any]],
    tool_recommender: Sim,
    log_progress: Callable[[Dict[str, Any]], None],
    verbosity: int = 0,
) -> Dict[str, str]:
    """Look up tool documentation relevant to every instruction of every plan.

    For each task string in ``plan["instructions"]`` the top 2 tools (similarity
    threshold 0.3) are fetched from ``tool_recommender``.

    Parameters:
        plans: Candidate plans keyed by plan name; each must contain an
            ``"instructions"`` iterable of task strings.
        tool_recommender: Similarity index whose ``top_k`` results are dicts
            holding ``"doc"`` and ``"desc"`` entries.
        log_progress: Callback that receives a single "started" progress log.
        verbosity: When 2, the de-duplicated tool descriptions are logged.

    Returns:
        A mapping from each plan name to its de-duplicated tool docs joined by
        blank lines, plus an ``"all"`` entry holding the docs of every plan.
        De-duplication keeps first-seen order (``dict.fromkeys``) rather than
        using ``set``, whose string ordering changes between interpreter runs;
        this keeps the prompts built from these docs reproducible.
    """
    log_progress(
        {
            "type": "log",
            "log_content": "Retrieving tools for each plan",
            "status": "started",
        }
    )
    all_docs: List[str] = []
    all_descs: List[str] = []
    docs_per_plan: Dict[str, List[str]] = {}
    for plan_name, plan in plans.items():
        plan_docs: List[str] = []
        docs_per_plan[plan_name] = plan_docs
        for task in plan["instructions"]:
            tools = tool_recommender.top_k(task, k=2, thresh=0.3)
            plan_docs.extend(tool["doc"] for tool in tools)
            all_descs.extend(tool["desc"] for tool in tools)
        all_docs.extend(plan_docs)

    if verbosity == 2:
        tool_desc_str = "\n".join(dict.fromkeys(all_descs))
        _LOGGER.info(f"Tools Description:\n{tool_desc_str}")

    tool_docs = {
        plan_name: "\n\n".join(dict.fromkeys(docs))
        for plan_name, docs in docs_per_plan.items()
    }
    tool_docs["all"] = "\n\n".join(dict.fromkeys(all_docs))
    return tool_docs
91
+
92
+
93
def write_plans(
    chat: List[Message], tool_desc: str, working_memory: str, model: LMM
) -> Dict[str, Any]:
    """Ask ``model`` to draft candidate plans for the latest user request.

    The chat is deep-copied first, so the caller's messages are not modified;
    only the copy's last message content is replaced by the planning prompt.

    Parameters:
        chat: Conversation whose last message must come from the user.
        tool_desc: Tool descriptions inserted into the PLAN prompt.
        working_memory: Formatted feedback inserted into the PLAN prompt.
        model: LMM that writes the plans.

    Returns:
        The JSON object extracted from the model's reply.

    Raises:
        ValueError: If the last message in ``chat`` is not from the user.
    """
    messages = copy.deepcopy(chat)
    last_message = messages[-1]
    if last_message["role"] != "user":
        raise ValueError("Last message in chat must be from user")

    last_message["content"] = PLAN.format(
        context=USER_REQ.format(user_request=last_message["content"]),
        tool_desc=tool_desc,
        feedback=working_memory,
    )
    reply = model(messages, stream=False)
    return extract_json(reply)  # type: ignore
109
+
110
+
111
def _report_plan_test_result(
    log_progress: Callable[[Dict[str, Any]], None],
    full_code: str,
    tool_output: Execution,
) -> None:
    """Send the success/failure progress log for one run of plan-test code."""
    succeeded = tool_output.success
    log_progress(
        {
            "type": "log",
            "log_content": (
                "Code execution succeeded" if succeeded else "Code execution failed"
            ),
            "code": full_code,
            "status": "completed" if succeeded else "failed",
        }
    )


def _plan_test_needs_retry(tool_output: Execution) -> bool:
    """Return True when the test code failed or wrote nothing to stdout/stderr."""
    if not tool_output.success:
        return True
    return len(tool_output.logs.stdout) == 0 and len(tool_output.logs.stderr) == 0


def write_and_exec_plan_tests(
    plans: Dict[str, Any],
    tool_info: str,
    media: List[str],
    model: LMM,
    log_progress: Callable[[Dict[str, Any]], None],
    code_interpreter: CodeInterpreter,
    verbosity: int = 0,
    max_retries: int = 3,
) -> Tuple[str, Execution]:
    """Generate code that exercises every plan's tools, run it, and retry on failure.

    The TEST_PLANS prompt is sent to ``model``; the extracted code is run in
    ``code_interpreter`` with the default imports prepended. If it fails or
    prints nothing, the model is re-prompted with the previous code and the
    last 50 lines of its output, up to ``max_retries`` extra attempts.

    Parameters:
        plans: Candidate plans, rendered with ``format_plans``.
        tool_info: Tool documentation inserted as the prompt's docstring.
        media: Media names available inside the interpreter.
        model: LMM that writes the test code.
        log_progress: Progress callback.
        code_interpreter: Interpreter used via ``exec_isolation``.
        verbosity: When 2, every attempt's code and output are logged.
        max_retries: Maximum number of regeneration attempts.

    Returns:
        ``(code, tool_output)`` for the last attempt: the generated code
        (without the prepended imports) and its execution result.
    """
    plan_str = format_plans(plans)
    prompt = TEST_PLANS.format(
        docstring=tool_info, plans=plan_str, previous_attempts="", media=media
    )

    code = extract_code(model(prompt, stream=False))  # type: ignore
    full_code = DefaultImports.prepend_imports(code)
    log_progress(
        {
            "type": "log",
            "log_content": "Executing code to test plans",
            "code": full_code,
            "status": "running",
        }
    )
    tool_output = code_interpreter.exec_isolation(full_code)
    # Because of the way we trace function calls the trace information ends up in the
    # results. We don't want to show this info to the LLM so we don't include it in the
    # tool_output_str.
    tool_output_str = tool_output.text(include_results=False).strip()

    if verbosity == 2:
        print_code("Initial code and tests:", code)
        _LOGGER.info(f"Initial code execution result:\n{tool_output_str}")

    _report_plan_test_result(log_progress, full_code, tool_output)

    # retry if the code fails or produces no output at all
    count = 0
    while _plan_test_needs_retry(tool_output) and count < max_retries:
        prompt = TEST_PLANS.format(
            docstring=tool_info,
            plans=plan_str,
            previous_attempts=PREVIOUS_FAILED.format(
                code=code, error="\n".join(tool_output_str.splitlines()[-50:])
            ),
            media=media,
        )
        # Reports the previous (failed) attempt's code before regenerating.
        log_progress(
            {
                "type": "log",
                "log_content": "Retrying code to test plans",
                "status": "running",
                "code": full_code,
            }
        )
        code = extract_code(model(prompt, stream=False))  # type: ignore
        full_code = DefaultImports.prepend_imports(code)
        tool_output = code_interpreter.exec_isolation(full_code)
        _report_plan_test_result(log_progress, full_code, tool_output)
        tool_output_str = tool_output.text(include_results=False).strip()

        if verbosity == 2:
            print_code("Code and test after attempted fix:", code)
            _LOGGER.info(f"Code execution result after attempt {count + 1}")
            _LOGGER.info(f"{tool_output_str}")

        count += 1

    return code, tool_output
210
+
211
+
212
def write_plan_thoughts(
    chat: List[Message],
    plans: Dict[str, Any],
    tool_output_str: str,
    model: LMM,
    max_retries: int = 3,
) -> Dict[str, str]:
    """Ask ``model`` which plan is best, given the plan-test output.

    Parameters:
        chat: Conversation whose last message holds the user request. It is
            deep-copied, so the caller's messages are left unchanged.
        plans: Candidate plans keyed by plan name.
        tool_output_str: Text output of the plan-test code; only the first
            20,000 characters are sent.
        model: LMM that picks the plan.
        max_retries: Maximum number of model calls while JSON extraction fails.

    Returns:
        A dict that always has ``"best_plan"`` (a key of ``plans``; the first
        plan if the model's answer is missing or unknown) and ``"thoughts"``
        (``""`` if the model gave none).
    """
    # Copy so the picker prompt does not overwrite the caller's last message
    # (write_plans deep-copies its chat the same way).
    chat = copy.deepcopy(chat)
    user_req = chat[-1]["content"]
    context = USER_REQ.format(user_request=user_req)
    # because the tool picker model gets the image as well, we have to be careful with
    # how much text we send it, so we truncate the tool output to 20,000 characters
    prompt = PICK_PLAN.format(
        context=context,
        plans=format_plans(plans),
        tool_output=tool_output_str[:20_000],
    )
    chat[-1]["content"] = prompt

    plan_thoughts = None
    count = 0
    while plan_thoughts is None and count < max_retries:
        try:
            plan_thoughts = extract_json(model(chat, stream=False))  # type: ignore
        except JSONDecodeError as e:
            _LOGGER.exception(
                f"Error while extracting JSON during picking best plan {str(e)}"
            )
        count += 1

    if (
        plan_thoughts is None
        or "best_plan" not in plan_thoughts
        or plan_thoughts["best_plan"] not in plans
    ):
        _LOGGER.info(f"Failed to pick best plan. Using the first plan. {plan_thoughts}")
        plan_thoughts = {"best_plan": list(plans.keys())[0]}

    if "thoughts" not in plan_thoughts:
        plan_thoughts["thoughts"] = ""
    return plan_thoughts
253
+
254
+
255
def pick_plan(
    chat: List[Message],
    plans: Dict[str, Any],
    tool_info: str,
    model: LMM,
    code_interpreter: CodeInterpreter,
    media: List[str],
    log_progress: Callable[[Dict[str, Any]], None],
    verbosity: int = 0,
    max_retries: int = 3,
) -> Tuple[Dict[str, str], str, Execution]:
    """Test every candidate plan with generated code, then let ``model`` pick one.

    Parameters:
        chat: Conversation whose last message must come from the user; it is
            deep-copied before use.
        plans: Candidate plans keyed by plan name.
        tool_info: Tool documentation for writing the test code.
        model: LMM used both to write the test code and to pick the plan.
        code_interpreter: Interpreter that runs the test code.
        media: Media names available inside the interpreter.
        log_progress: Progress callback.
        verbosity: When >= 1, the final code and chosen plan are logged.
        max_retries: Retry budget for test-code generation and plan picking.

    Returns:
        ``(plan_thoughts, code, tool_output)``: the picker's decision (with
        ``"best_plan"`` and ``"thoughts"`` keys), the final test code, and its
        execution result.

    Raises:
        ValueError: If the last chat message is not from the user.
    """
    log_progress(
        {
            "type": "log",
            "log_content": "Generating code to pick the best plan",
            "status": "started",
        }
    )

    chat_copy = copy.deepcopy(chat)
    if chat_copy[-1]["role"] != "user":
        raise ValueError("Last chat message must be from the user.")

    code, tool_output = write_and_exec_plan_tests(
        plans=plans,
        tool_info=tool_info,
        media=media,
        model=model,
        log_progress=log_progress,
        code_interpreter=code_interpreter,
        verbosity=verbosity,
        max_retries=max_retries,
    )
    if verbosity >= 1:
        print_code("Final code:", code)

    test_output_text = tool_output.text(include_results=False).strip()
    plan_thoughts = write_plan_thoughts(
        chat=chat_copy,
        plans=plans,
        tool_output_str=test_output_text,
        model=model,
        max_retries=max_retries,
    )

    if verbosity >= 1:
        _LOGGER.info(f"Best plan:\n{plan_thoughts}")

    chosen_plan = plans[plan_thoughts["best_plan"]]
    log_progress(
        {
            "type": "log",
            "log_content": "Picked best plan",
            "status": "completed",
            "payload": chosen_plan,
        }
    )
    return plan_thoughts, code, tool_output
312
+
313
+
314
class VisionAgentPlanner(Agent):
    """Agent that drafts candidate plans for a vision task and selects the best one."""

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        """
        Parameters:
            planner: LMM that writes and judges plans; defaults to
                ``AnthropicLMM(temperature=0.0)``.
            tool_recommender: Similarity index used to retrieve tool docs;
                defaults to ``Sim(T.TOOLS_DF, sim_key="desc")``.
            verbosity: 0 is quiet; > 0 sets this module's logger to INFO; 2
                additionally logs tool descriptions and every test attempt.
            report_progress_callback: Optional callable receiving progress logs.
            code_interpreter: A CodeInterpreter instance, a runtime name passed
                to ``CodeInterpreterFactory.new_instance``, or None, in which
                case a new instance is created per ``generate_plan`` call.
        """
        self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner
        self.verbosity = verbosity
        if self.verbosity > 0:
            _LOGGER.setLevel(logging.INFO)

        self.tool_recommender = (
            Sim(T.TOOLS_DF, sim_key="desc")
            if tool_recommender is None
            else tool_recommender
        )
        self.report_progress_callback = report_progress_callback
        self.code_interpreter = code_interpreter

    def __call__(
        self, input: Union[str, List[Message]], media: Optional[Union[str, Path]] = None
    ) -> str:
        """Plan for ``input`` and return the best plan rendered with ``str``.

        A plain string becomes a single user message; ``media`` is attached to
        that message only in the string case.
        """
        if isinstance(input, str):
            input = [{"role": "user", "content": input}]
            if media is not None:
                input[0]["media"] = [media]
        planning_context = self.generate_plan(input)
        return str(planning_context.plans[planning_context.best_plan])

    def _resolve_code_interpreter(
        self, code_interpreter: Optional[CodeInterpreter] = None
    ) -> CodeInterpreter:
        """Choose the interpreter for one ``generate_plan`` call.

        An explicit argument wins; otherwise an interpreter instance stored on
        ``self`` is reused, and a stored runtime name *or None* is turned into a
        fresh instance via ``CodeInterpreterFactory.new_instance``.
        """
        if code_interpreter is not None:
            return code_interpreter
        if self.code_interpreter is None or isinstance(self.code_interpreter, str):
            # Previously a stored None was returned as-is and `with None:` raised
            # TypeError for the default constructor. NOTE(review): assumes
            # new_instance(None) selects the default runtime — confirm.
            return CodeInterpreterFactory.new_instance(self.code_interpreter)
        return self.code_interpreter

    def generate_plan(
        self,
        chat: List[Message],
        test_multi_plan: bool = True,
        custom_tool_names: Optional[List[str]] = None,
        code_interpreter: Optional[CodeInterpreter] = None,
    ) -> PlanContext:
        """Write candidate plans for ``chat`` and select the best one.

        Parameters:
            chat: Conversation messages; the last must be from the user. Each
                ``"media"`` entry that is not an http(s) URL string is uploaded
                to the interpreter, and every media name is appended to its
                message content. The caller's list is not modified.
            test_multi_plan: When True, generate and run test code for all
                plans and let the planner pick; otherwise use the first plan.
            custom_tool_names: Tool names forwarded to
                ``T.get_tool_descriptions_by_names`` for the planning prompt.
            code_interpreter: Interpreter for this call, overriding the one
                configured on the instance.

        Returns:
            A PlanContext with all plans, the chosen plan and its tool docs.

        Raises:
            ValueError: If ``chat`` is empty or its last message is not from
                the user.
        """
        if not chat:
            raise ValueError("Chat cannot be empty")

        code_interpreter = self._resolve_code_interpreter(code_interpreter)
        with code_interpreter:
            chat = copy.deepcopy(chat)
            media_list: List[str] = []
            for chat_i in chat:
                if "media" in chat_i:
                    for media in chat_i["media"]:
                        media = (
                            media
                            if type(media) is str
                            and media.startswith(("http", "https"))
                            else code_interpreter.upload_file(cast(str, media))
                        )
                        chat_i["content"] += f" Media name {media}"  # type: ignore
                        media_list.append(str(media))

            # Copy of the chat limited to role/content/media for the plan picker.
            int_chat = cast(
                List[Message],
                [
                    (
                        {
                            "role": c["role"],
                            "content": c["content"],
                            "media": c["media"],
                        }
                        if "media" in c
                        else {"role": c["role"], "content": c["content"]}
                    )
                    for c in chat
                ],
            )

            working_memory: List[Dict[str, str]] = []

            plans = write_plans(
                chat,
                T.get_tool_descriptions_by_names(
                    custom_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS  # type: ignore
                ),
                format_memory(working_memory),
                self.planner,
            )

            tool_docs = retrieve_tools(
                plans,
                self.tool_recommender,
                self.log_progress,
                self.verbosity,
            )
            if test_multi_plan:
                plan_thoughts, code, tool_output = pick_plan(
                    int_chat,
                    plans,
                    tool_docs["all"],
                    self.planner,
                    code_interpreter,
                    media_list,
                    self.log_progress,
                    self.verbosity,
                )
                best_plan = plan_thoughts["best_plan"]
                plan_thoughts_str = plan_thoughts["thoughts"]
                tool_output_str = (
                    "```python\n"
                    + code
                    + "\n```\n"
                    + tool_output.text(include_results=False).strip()
                )
            else:
                best_plan = list(plans.keys())[0]
                tool_output_str = ""
                plan_thoughts_str = ""
                tool_output = None

            if best_plan in plans and best_plan in tool_docs:
                tool_doc = tool_docs[best_plan]
            else:
                if self.verbosity >= 1:
                    _LOGGER.warning(
                        f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
                    )
                k = list(plans.keys())[0]
                best_plan = k
                tool_doc = tool_docs[k]

            return PlanContext(
                plans=plans,
                best_plan=best_plan,
                plan_thoughts=plan_thoughts_str,
                tool_output=tool_output_str,
                test_results=tool_output,
                tool_doc=tool_doc,
            )

    def log_progress(self, log: Dict[str, Any]) -> None:
        """Forward ``log`` to the progress callback, if one was configured."""
        if self.report_progress_callback is not None:
            self.report_progress_callback(log)
462
+
463
+
464
class AnthropicVisionAgentPlanner(VisionAgentPlanner):
    """VisionAgentPlanner whose default planner is ``AnthropicLMM(temperature=0.0)``."""

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        if planner is None:
            planner = AnthropicLMM(temperature=0.0)
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )
480
+
481
+
482
class OpenAIVisionAgentPlanner(VisionAgentPlanner):
    """VisionAgentPlanner whose default planner is an OpenAI LMM in JSON mode."""

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        if planner is None:
            planner = OpenAILMM(temperature=0.0, json_mode=True)
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )
502
+
503
+
504
class OllamaVisionAgentPlanner(VisionAgentPlanner):
    """VisionAgentPlanner defaulting to Ollama for both planning and tool retrieval.

    Defaults: ``OllamaLMM(model_name="llama3.1", temperature=0.0)`` as planner
    and ``OllamaSim(T.TOOLS_DF, sim_key="desc")`` as tool recommender.
    """

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        if planner is None:
            planner = OllamaLMM(model_name="llama3.1", temperature=0.0)
        if tool_recommender is None:
            tool_recommender = OllamaSim(T.TOOLS_DF, sim_key="desc")
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )
528
+
529
+
530
class AzureVisionAgentPlanner(VisionAgentPlanner):
    """VisionAgentPlanner defaulting to Azure OpenAI for planning and tool retrieval.

    Defaults: ``AzureOpenAILMM(temperature=0.0, json_mode=True)`` as planner
    and ``AzureSim(T.TOOLS_DF, sim_key="desc")`` as tool recommender.
    """

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        if planner is None:
            planner = AzureOpenAILMM(temperature=0.0, json_mode=True)
        if tool_recommender is None:
            tool_recommender = AzureSim(T.TOOLS_DF, sim_key="desc")
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )