vision-agent 0.2.161__py3-none-any.whl → 0.2.162__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,553 @@
1
+ import copy
2
+ import logging
3
+ from json import JSONDecodeError
4
+ from pathlib import Path
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
6
+
7
+ from pydantic import BaseModel
8
+
9
+ import vision_agent.tools as T
10
+ from vision_agent.agent import Agent
11
+ from vision_agent.agent.agent_utils import (
12
+ DefaultImports,
13
+ extract_code,
14
+ extract_json,
15
+ format_memory,
16
+ format_plans,
17
+ print_code,
18
+ )
19
+ from vision_agent.agent.vision_agent_planner_prompts import (
20
+ PICK_PLAN,
21
+ PLAN,
22
+ PREVIOUS_FAILED,
23
+ TEST_PLANS,
24
+ USER_REQ,
25
+ )
26
+ from vision_agent.lmm import (
27
+ LMM,
28
+ AnthropicLMM,
29
+ AzureOpenAILMM,
30
+ Message,
31
+ OllamaLMM,
32
+ OpenAILMM,
33
+ )
34
+ from vision_agent.utils.execute import (
35
+ CodeInterpreter,
36
+ CodeInterpreterFactory,
37
+ Execution,
38
+ )
39
+ from vision_agent.utils.sim import AzureSim, OllamaSim, Sim
40
+
41
# Logger shared by everything in this module; VisionAgentPlanner raises its
# level to INFO when constructed with verbosity > 0.
_LOGGER: logging.Logger = logging.getLogger(__name__)
42
+
43
+
44
class PlanContext(BaseModel):
    """Outcome of one planning run, as returned by ``VisionAgentPlanner.generate_plan``."""

    # Every candidate plan, keyed by plan name; each plan maps field names
    # (e.g. "instructions") to a string or a list of strings.
    plans: Dict[str, Dict[str, Union[str, List[str]]]]
    # Key into ``plans`` of the selected plan.
    best_plan: str
    # The model's reasoning for its choice ("" when plans were not tested).
    plan_thoughts: str
    # Markdown-fenced test code followed by its text output ("" when not tested).
    tool_output: str
    # Documentation of the tools retrieved for the selected plan.
    tool_doc: str
    # Raw execution result of the plan-testing code; None when plans were not tested.
    test_results: Optional[Execution]
51
+
52
+
53
def retrieve_tools(
    plans: Dict[str, Dict[str, Any]],
    tool_recommender: Sim,
    log_progress: Callable[[Dict[str, Any]], None],
    verbosity: int = 0,
) -> Dict[str, str]:
    """Collect documentation for the tools relevant to each plan.

    For every instruction of every plan, ``tool_recommender.top_k(task, k=2,
    thresh=0.3)`` is queried and the ``"doc"`` / ``"desc"`` fields of the hits
    are gathered.

    Args:
        plans: Mapping from plan name to plan; each plan needs an
            ``"instructions"`` iterable of task strings.
        tool_recommender: Similarity index over the available tools; its
            ``top_k`` results must expose ``"doc"`` and ``"desc"`` keys.
        log_progress: Callback that receives progress-log dicts.
        verbosity: When 2, the de-duplicated tool descriptions are logged.

    Returns:
        A dict mapping every plan name to the de-duplicated documentation of
        its tools joined by blank lines, plus an ``"all"`` entry with the
        de-duplicated documentation across all plans. (A plan literally named
        ``"all"`` is overwritten by that entry.)
    """
    log_progress(
        {
            "type": "log",
            "log_content": "Retrieving tools for each plan",
            "status": "started",
        }
    )
    all_docs: List[str] = []
    all_descs: List[str] = []
    docs_per_plan: Dict[str, List[str]] = {}
    for plan_name, plan in plans.items():
        plan_docs: List[str] = []
        docs_per_plan[plan_name] = plan_docs
        for task in plan["instructions"]:
            for tool in tool_recommender.top_k(task, k=2, thresh=0.3):
                all_docs.append(tool["doc"])
                all_descs.append(tool["desc"])
                plan_docs.append(tool["doc"])

    # De-duplicate with dict.fromkeys instead of set(): it keeps first-seen
    # order, so the joined text (which is fed into LLM prompts) is the same on
    # every run rather than depending on per-process string-hash randomization.
    if verbosity == 2:
        tool_desc_str = "\n".join(dict.fromkeys(all_descs))
        _LOGGER.info(f"Tools Description:\n{tool_desc_str}")

    tool_docs = {
        plan_name: "\n\n".join(dict.fromkeys(docs))
        for plan_name, docs in docs_per_plan.items()
    }
    tool_docs["all"] = "\n\n".join(dict.fromkeys(all_docs))
    return tool_docs
91
+
92
+
93
def write_plans(
    chat: List[Message], tool_desc: str, working_memory: str, model: LMM
) -> Dict[str, Any]:
    """Ask ``model`` to draft candidate plans for the latest user request.

    The caller's ``chat`` is left untouched: a deep copy is made and its last
    message is replaced by the PLAN prompt before querying the model.

    Args:
        chat: Conversation whose final message must come from the user.
        tool_desc: Tool descriptions inserted into the prompt.
        working_memory: Formatted feedback inserted into the prompt.
        model: LMM that writes the plans.

    Returns:
        The JSON object extracted from the model's reply.

    Raises:
        ValueError: If the last message is not from the user.
    """
    conversation = copy.deepcopy(chat)
    last_message = conversation[-1]
    if last_message["role"] != "user":
        raise ValueError("Last message in chat must be from user")

    planning_prompt = PLAN.format(
        context=USER_REQ.format(user_request=last_message["content"]),
        tool_desc=tool_desc,
        feedback=working_memory,
    )
    last_message["content"] = planning_prompt
    reply = model(conversation, stream=False)
    return extract_json(reply)  # type: ignore
109
+
110
+
111
def _log_exec_result(
    log_progress: Callable[[Dict[str, Any]], None], code: str, tool_output: Execution
) -> None:
    """Report through ``log_progress`` whether running the plan-test ``code`` succeeded."""
    log_progress(
        {
            "type": "log",
            "log_content": (
                "Code execution succeeded"
                if tool_output.success
                else "Code execution failed"
            ),
            "code": DefaultImports.prepend_imports(code),
            "status": "completed" if tool_output.success else "failed",
        }
    )


def _needs_retry(tool_output: Execution) -> bool:
    """Return True when the test run failed or wrote nothing to stdout and stderr."""
    logs = tool_output.logs
    return not tool_output.success or (
        len(logs.stdout) == 0 and len(logs.stderr) == 0
    )


def write_and_exec_plan_tests(
    plans: Dict[str, Any],
    tool_info: str,
    media: List[str],
    model: LMM,
    log_progress: Callable[[Dict[str, Any]], None],
    code_interpreter: CodeInterpreter,
    verbosity: int = 0,
    max_retries: int = 3,
) -> Tuple[str, Execution]:
    """Have ``model`` write code that exercises the plans, run it, and retry on failure.

    The generated code is executed (with ``DefaultImports`` prepended) via
    ``code_interpreter.exec_isolation``. If execution fails, or succeeds
    without printing anything, the model is re-prompted with the previous
    code and the last 50 lines of its output, at most ``max_retries`` times.

    Args:
        plans: Candidate plans, rendered into the prompt with ``format_plans``.
        tool_info: Tool documentation inserted into the prompt.
        media: Media names inserted into the prompt.
        model: LMM that writes the test code.
        log_progress: Callback that receives progress-log dicts.
        code_interpreter: Interpreter used to run the code.
        verbosity: When 2, every code version and its output are printed/logged.
        max_retries: Maximum number of regeneration attempts.

    Returns:
        ``(code, tool_output)``: the last generated code (without the default
        imports) and its execution result.
    """
    plan_str = format_plans(plans)
    prompt = TEST_PLANS.format(
        docstring=tool_info, plans=plan_str, previous_attempts="", media=media
    )

    code = extract_code(model(prompt, stream=False))  # type: ignore
    full_code = DefaultImports.prepend_imports(code)
    log_progress(
        {
            "type": "log",
            "log_content": "Executing code to test plans",
            "code": full_code,
            "status": "running",
        }
    )
    tool_output = code_interpreter.exec_isolation(full_code)
    # Because of the way we trace function calls the trace information ends up in the
    # results. We don't want to show this info to the LLM so we don't include it in the
    # tool_output_str.
    tool_output_str = tool_output.text(include_results=False).strip()

    if verbosity == 2:
        print_code("Initial code and tests:", code)
        _LOGGER.info(f"Initial code execution result:\n{tool_output_str}")

    _log_exec_result(log_progress, code, tool_output)

    # Retry if the code failed or produced no output at all.
    for attempt in range(max_retries):
        if not _needs_retry(tool_output):
            break
        prompt = TEST_PLANS.format(
            docstring=tool_info,
            plans=plan_str,
            previous_attempts=PREVIOUS_FAILED.format(
                code=code, error="\n".join(tool_output_str.splitlines()[-50:])
            ),
            media=media,
        )
        log_progress(
            {
                "type": "log",
                "log_content": "Retrying code to test plans",
                "status": "running",
                "code": DefaultImports.prepend_imports(code),
            }
        )
        code = extract_code(model(prompt, stream=False))  # type: ignore
        tool_output = code_interpreter.exec_isolation(
            DefaultImports.prepend_imports(code)
        )
        _log_exec_result(log_progress, code, tool_output)
        tool_output_str = tool_output.text(include_results=False).strip()

        if verbosity == 2:
            print_code("Code and test after attempted fix:", code)
            _LOGGER.info(f"Code execution result after attempt {attempt + 1}")
            _LOGGER.info(f"{tool_output_str}")

    return code, tool_output
210
+
211
+
212
def _is_valid_plan_choice(plan_thoughts: Any, plans: Dict[str, Any]) -> bool:
    """Return True if ``plan_thoughts`` is a dict whose ``"best_plan"`` is a key of ``plans``."""
    if not isinstance(plan_thoughts, dict) or "best_plan" not in plan_thoughts:
        return False
    try:
        return plan_thoughts["best_plan"] in plans
    except TypeError:
        # e.g. the model answered with a list, which cannot be a dict key
        return False


def write_plan_thoughts(
    chat: List[Message],
    plans: Dict[str, Any],
    tool_output_str: str,
    model: LMM,
    max_retries: int = 3,
) -> Dict[str, str]:
    """Ask ``model`` to choose the best plan given the plan-test output.

    The last message of a deep copy of ``chat`` is replaced by the PICK_PLAN
    prompt; the caller's ``chat`` is not modified. The model is queried up to
    ``max_retries`` times until it returns parseable JSON.

    Args:
        chat: Conversation whose last message holds the user request.
        plans: Candidate plans keyed by plan name.
        tool_output_str: Text output of the plan-testing code; only the first
            20,000 characters are sent.
        model: LMM that picks the plan.
        max_retries: Maximum number of model queries.

    Returns:
        The model's JSON answer with a ``"best_plan"`` key that is guaranteed
        to be in ``plans`` and a ``"thoughts"`` key (``""`` if absent). If the
        model gave no valid choice, ``{"best_plan": <first plan>, "thoughts": ""}``.

    Raises:
        ValueError: If ``plans`` is empty, so no fallback plan exists.
    """
    # Deep-copy so the caller's last message is not overwritten with the
    # PICK_PLAN prompt (write_plans copies its chat the same way).
    chat = copy.deepcopy(chat)
    user_req = chat[-1]["content"]
    context = USER_REQ.format(user_request=user_req)
    # because the tool picker model gets the image as well, we have to be careful with
    # how much text we send it, so we truncate the tool output to 20,000 characters
    prompt = PICK_PLAN.format(
        context=context,
        plans=format_plans(plans),
        tool_output=tool_output_str[:20_000],
    )
    chat[-1]["content"] = prompt

    plan_thoughts: Any = None
    for _ in range(max_retries):
        try:
            plan_thoughts = extract_json(model(chat, stream=False))  # type: ignore
        except JSONDecodeError as e:
            _LOGGER.exception(
                f"Error while extracting JSON during picking best plan {str(e)}"
            )
        if plan_thoughts is not None:
            break

    if not _is_valid_plan_choice(plan_thoughts, plans):
        _LOGGER.info(f"Failed to pick best plan. Using the first plan. {plan_thoughts}")
        if not plans:
            raise ValueError("No plans were provided to pick the best plan from")
        plan_thoughts = {"best_plan": next(iter(plans))}

    plan_thoughts.setdefault("thoughts", "")
    return plan_thoughts
253
+
254
+
255
def pick_plan(
    chat: List[Message],
    plans: Dict[str, Any],
    tool_info: str,
    model: LMM,
    code_interpreter: CodeInterpreter,
    media: List[str],
    log_progress: Callable[[Dict[str, Any]], None],
    verbosity: int = 0,
    max_retries: int = 3,
) -> Tuple[Dict[str, str], str, Execution]:
    """Test the candidate plans with generated code, then let ``model`` pick one.

    Args:
        chat: Conversation whose last message must come from the user; it is
            deep-copied and not modified.
        plans: Candidate plans keyed by plan name.
        tool_info: Tool documentation used when writing the test code.
        model: LMM that writes the test code and picks the plan.
        code_interpreter: Interpreter used to run the test code.
        media: Media names made available to the test code prompt.
        log_progress: Callback that receives progress-log dicts.
        verbosity: When >= 1, the final code and the chosen plan are shown.
        max_retries: Retry budget for code generation and for plan picking.

    Returns:
        ``(plan_thoughts, code, tool_output)``: the model's choice (with
        ``"best_plan"`` and ``"thoughts"``), the final test code, and its
        execution result.

    Raises:
        ValueError: If the last chat message is not from the user.
    """
    log_progress(
        {
            "type": "log",
            "log_content": "Generating code to pick the best plan",
            "status": "started",
        }
    )

    conversation = copy.deepcopy(chat)
    if conversation[-1]["role"] != "user":
        raise ValueError("Last chat message must be from the user.")

    test_code, test_result = write_and_exec_plan_tests(
        plans=plans,
        tool_info=tool_info,
        media=media,
        model=model,
        log_progress=log_progress,
        code_interpreter=code_interpreter,
        verbosity=verbosity,
        max_retries=max_retries,
    )
    if verbosity >= 1:
        print_code("Final code:", test_code)

    result_text = test_result.text(include_results=False).strip()
    plan_thoughts = write_plan_thoughts(
        chat=conversation,
        plans=plans,
        tool_output_str=result_text,
        model=model,
        max_retries=max_retries,
    )

    if verbosity >= 1:
        _LOGGER.info(f"Best plan:\n{plan_thoughts}")
    chosen_plan = plans[plan_thoughts["best_plan"]]
    log_progress(
        {
            "type": "log",
            "log_content": "Picked best plan",
            "status": "completed",
            "payload": chosen_plan,
        }
    )
    return plan_thoughts, test_code, test_result
312
+
313
+
314
class VisionAgentPlanner(Agent):
    """Agent that drafts candidate plans for a vision task and picks one.

    Plans are written by the ``planner`` LMM, tool documentation for each plan
    step is retrieved with ``tool_recommender``, and, optionally, the plans are
    tested by executing model-written code before the best one is chosen.
    """

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        """Configure the planner.

        Args:
            planner: LMM used to write and choose plans. Defaults to
                ``AnthropicLMM(temperature=0.0)``.
            tool_recommender: Similarity index used to find tools for each plan
                step. Defaults to ``Sim(T.TOOLS_DF, sim_key="desc")``.
            verbosity: Any value above 0 sets this module's logger to INFO;
                >= 1 also shows the final test code and picked plan; 2 also
                logs tool descriptions and every test attempt.
            report_progress_callback: Optional callable that receives every
                progress-log dict (see ``log_progress``).
            code_interpreter: A ``CodeInterpreter`` to reuse, or a runtime name
                for ``CodeInterpreterFactory.new_instance``. When None, every
                ``generate_plan`` call without an explicit interpreter creates
                one from the factory.
        """
        self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner
        self.verbosity = verbosity
        if self.verbosity > 0:
            _LOGGER.setLevel(logging.INFO)

        self.tool_recommender = (
            Sim(T.TOOLS_DF, sim_key="desc")
            if tool_recommender is None
            else tool_recommender
        )
        self.report_progress_callback = report_progress_callback
        self.code_interpreter = code_interpreter

    def __call__(
        self, input: Union[str, List[Message]], media: Optional[Union[str, Path]] = None
    ) -> str:
        """Plan for ``input`` and return the chosen plan rendered as a string.

        ``media`` is attached to the message only when ``input`` is a plain
        string; a message list is used as given.
        """
        if isinstance(input, str):
            input = [{"role": "user", "content": input}]
            if media is not None:
                input[0]["media"] = [media]
        planning_context = self.generate_plan(input)
        return str(planning_context.plans[planning_context.best_plan])

    def _resolve_code_interpreter(
        self, code_interpreter: Optional[CodeInterpreter]
    ) -> CodeInterpreter:
        """Pick the interpreter for one planning run.

        An explicitly passed interpreter wins; otherwise ``self.code_interpreter``
        is reused when it is an instance, and a new one is created from the
        factory when it is a runtime name or None.
        """
        if code_interpreter is not None:
            return code_interpreter
        if self.code_interpreter is None or isinstance(self.code_interpreter, str):
            # None must not be used directly: ``with None:`` raises TypeError,
            # which made generate_plan fail on a default-constructed planner.
            return CodeInterpreterFactory.new_instance(self.code_interpreter)
        return self.code_interpreter

    @staticmethod
    def _attach_media(
        chat: List[Message], code_interpreter: CodeInterpreter
    ) -> List[str]:
        """Mention each message's media in its content and return all media names.

        Strings starting with "http" are used as-is; anything else is uploaded
        with ``code_interpreter.upload_file`` and the returned name is used.
        ``chat`` is modified in place (only the ``"content"`` fields change).
        """
        media_list: List[str] = []
        for message in chat:
            if "media" not in message:
                continue
            for media in message["media"]:
                name = (
                    media
                    if isinstance(media, str) and media.startswith(("http", "https"))
                    else code_interpreter.upload_file(cast(str, media))
                )
                message["content"] += f" Media name {name}"  # type: ignore
                media_list.append(str(name))
        return media_list

    def generate_plan(
        self,
        chat: List[Message],
        test_multi_plan: bool = True,
        custom_tool_names: Optional[List[str]] = None,
        code_interpreter: Optional[CodeInterpreter] = None,
    ) -> PlanContext:
        """Draft candidate plans for ``chat`` and select one.

        Args:
            chat: Conversation (not modified); messages may carry a ``"media"``
                list. The last message must come from the user.
            test_multi_plan: When True, the plans are tested with generated code
                and the model picks the best one; when False the first plan is
                used without testing.
            custom_tool_names: Forwarded to ``T.get_tool_descriptions_by_names``
                when building the tool descriptions given to the planner.
            code_interpreter: Interpreter for this call, overriding the one
                configured on the instance. It is entered as a context manager
                for the duration of the call.

        Returns:
            A ``PlanContext`` describing all plans and the selected one.

        Raises:
            ValueError: If ``chat`` is empty, its last message is not from the
                user, or the planner produced no plans.
        """
        if not chat:
            raise ValueError("Chat cannot be empty")

        code_interpreter = self._resolve_code_interpreter(code_interpreter)
        with code_interpreter:
            # Work on a copy: media mentions are appended to message contents.
            chat = copy.deepcopy(chat)
            media_list = self._attach_media(chat, code_interpreter)

            # Reduced conversation (role, content and, if present, media only)
            # used when picking the best plan.
            int_chat = cast(
                List[Message],
                [
                    (
                        {
                            "role": c["role"],
                            "content": c["content"],
                            "media": c["media"],
                        }
                        if "media" in c
                        else {"role": c["role"], "content": c["content"]}
                    )
                    for c in chat
                ],
            )

            working_memory: List[Dict[str, str]] = []

            plans = write_plans(
                chat,
                T.get_tool_descriptions_by_names(
                    custom_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS  # type: ignore
                ),
                format_memory(working_memory),
                self.planner,
            )
            if not plans:
                raise ValueError("The planner did not produce any plans")

            tool_docs = retrieve_tools(
                plans,
                self.tool_recommender,
                self.log_progress,
                self.verbosity,
            )

            tool_output: Optional[Execution] = None
            if test_multi_plan:
                plan_thoughts, code, tool_output = pick_plan(
                    int_chat,
                    plans,
                    tool_docs["all"],
                    self.planner,
                    code_interpreter,
                    media_list,
                    self.log_progress,
                    self.verbosity,
                )
                best_plan = plan_thoughts["best_plan"]
                plan_thoughts_str = plan_thoughts["thoughts"]
                tool_output_str = (
                    "```python\n"
                    + code
                    + "\n```\n"
                    + tool_output.text(include_results=False).strip()
                )
            else:
                best_plan = next(iter(plans))
                tool_output_str = ""
                plan_thoughts_str = ""

            if best_plan not in plans or best_plan not in tool_docs:
                if self.verbosity >= 1:
                    _LOGGER.warning(
                        f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
                    )
                best_plan = next(iter(plans))
            tool_doc = tool_docs[best_plan]

            return PlanContext(
                plans=plans,
                best_plan=best_plan,
                plan_thoughts=plan_thoughts_str,
                tool_output=tool_output_str,
                test_results=tool_output,
                tool_doc=tool_doc,
            )

    def log_progress(self, log: Dict[str, Any]) -> None:
        """Forward ``log`` to the progress callback, if one was configured."""
        if self.report_progress_callback is not None:
            self.report_progress_callback(log)
462
+
463
+
464
class AnthropicVisionAgentPlanner(VisionAgentPlanner):
    """``VisionAgentPlanner`` that plans with an Anthropic model by default."""

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        """Accept the same arguments as ``VisionAgentPlanner``.

        ``planner`` defaults to ``AnthropicLMM(temperature=0.0)``; every other
        argument is forwarded unchanged.
        """
        if planner is None:
            planner = AnthropicLMM(temperature=0.0)
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )
480
+
481
+
482
class OpenAIVisionAgentPlanner(VisionAgentPlanner):
    """``VisionAgentPlanner`` that plans with an OpenAI model by default."""

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        """Accept the same arguments as ``VisionAgentPlanner``.

        ``planner`` defaults to ``OpenAILMM(temperature=0.0, json_mode=True)``;
        every other argument is forwarded unchanged.
        """
        if planner is None:
            planner = OpenAILMM(temperature=0.0, json_mode=True)
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )
502
+
503
+
504
class OllamaVisionAgentPlanner(VisionAgentPlanner):
    """``VisionAgentPlanner`` backed by a local Ollama model and Ollama embeddings by default."""

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        """Accept the same arguments as ``VisionAgentPlanner``.

        ``planner`` defaults to ``OllamaLMM(model_name="llama3.1",
        temperature=0.0)`` and ``tool_recommender`` to
        ``OllamaSim(T.TOOLS_DF, sim_key="desc")``; the rest are forwarded.
        """
        if planner is None:
            planner = OllamaLMM(model_name="llama3.1", temperature=0.0)
        if tool_recommender is None:
            tool_recommender = OllamaSim(T.TOOLS_DF, sim_key="desc")
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )
528
+
529
+
530
class AzureVisionAgentPlanner(VisionAgentPlanner):
    """``VisionAgentPlanner`` backed by Azure OpenAI models by default."""

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        """Accept the same arguments as ``VisionAgentPlanner``.

        ``planner`` defaults to ``AzureOpenAILMM(temperature=0.0,
        json_mode=True)`` and ``tool_recommender`` to
        ``AzureSim(T.TOOLS_DF, sim_key="desc")``; the rest are forwarded.
        """
        if planner is None:
            planner = AzureOpenAILMM(temperature=0.0, json_mode=True)
        if tool_recommender is None:
            tool_recommender = AzureSim(T.TOOLS_DF, sim_key="desc")
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )