vision-agent 1.0.4__py3-none-any.whl → 1.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,564 +0,0 @@
1
- import copy
2
- import logging
3
- from json import JSONDecodeError
4
- from pathlib import Path
5
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
6
-
7
- from pydantic import BaseModel
8
- from tabulate import tabulate
9
-
10
- import vision_agent.tools as T
11
- from vision_agent.agent import Agent
12
- from vision_agent.agent.vision_agent_planner_prompts import (
13
- PICK_PLAN,
14
- PLAN,
15
- PREVIOUS_FAILED,
16
- TEST_PLANS,
17
- USER_REQ,
18
- )
19
- from vision_agent.lmm import LMM, AnthropicLMM, AzureOpenAILMM, OllamaLMM, OpenAILMM
20
- from vision_agent.models import Message
21
- from vision_agent.sim import AzureSim, OllamaSim, Sim
22
- from vision_agent.utils.agent import (
23
- _MAX_TABULATE_COL_WIDTH,
24
- DefaultImports,
25
- extract_code,
26
- extract_json,
27
- format_feedback,
28
- format_plans,
29
- print_code,
30
- )
31
- from vision_agent.utils.execute import (
32
- CodeInterpreter,
33
- CodeInterpreterFactory,
34
- Execution,
35
- )
36
- from vision_agent.utils.tools_doc import get_tool_descriptions_by_names
37
-
38
- _LOGGER = logging.getLogger(__name__)
39
-
40
-
41
class PlanContext(BaseModel):
    # Result bundle built at the end of VisionAgentPlanner.generate_plan.

    # Candidate plans keyed by plan name; each plan is a dict of string or
    # list-of-string values (generate_plan reads its "instructions" list).
    plans: Dict[str, Dict[str, Union[str, List[str]]]]
    # Key into `plans` naming the plan that was selected.
    best_plan: str
    # The "thoughts" text returned when picking the plan; empty string when
    # plan testing was skipped.
    plan_thoughts: str
    # Fenced test code followed by its execution text; empty string when plan
    # testing was skipped.
    tool_output: str
    # Tool documentation retrieved for the selected plan.
    tool_doc: str
    # Execution object from the plan-test run; None when plan testing was skipped.
    test_results: Optional[Execution]
48
-
49
-
50
def retrieve_tools(
    plans: Dict[str, Dict[str, Any]],
    tool_recommender: Sim,
    log_progress: Callable[[Dict[str, Any]], None],
    verbosity: int = 0,
) -> Dict[str, str]:
    """Collect tool documentation for every instruction of every plan.

    For each entry of a plan's ``"instructions"`` list the recommender is
    queried with ``top_k(task, k=2, thresh=0.3)``; the ``"doc"`` fields of the
    hits are de-duplicated and joined with blank lines.

    De-duplication keeps first-seen order (``dict.fromkeys``) rather than going
    through a ``set``, so the returned strings are identical from run to run
    instead of depending on string hash ordering.

    Args:
        plans: Mapping of plan name to a plan dict holding an
            ``"instructions"`` iterable.
        tool_recommender: Object exposing ``top_k(task, k=..., thresh=...)``
            that returns items with ``"doc"`` and ``"desc"`` keys.
        log_progress: Callback that receives a progress dict.
        verbosity: When exactly 2, the unique tool descriptions are logged.

    Returns:
        Mapping of plan name to its joined tool documentation, plus an
        ``"all"`` entry with the documentation of every retrieved tool. A plan
        that is itself named ``"all"`` is overwritten by that entry.
    """
    log_progress(
        {
            "type": "log",
            "log_content": ("Retrieving tools for each plan"),
            "status": "started",
        }
    )

    all_docs: List[str] = []
    all_descs: List[str] = []
    docs_by_plan: Dict[str, List[str]] = {}
    for plan_name, plan in plans.items():
        plan_docs: List[str] = []
        docs_by_plan[plan_name] = plan_docs
        for task in plan["instructions"]:
            for tool in tool_recommender.top_k(task, k=2, thresh=0.3):
                all_docs.append(tool["doc"])
                all_descs.append(tool["desc"])
                plan_docs.append(tool["doc"])

    if verbosity == 2:
        tool_desc_str = "\n".join(dict.fromkeys(all_descs))
        _LOGGER.info(f"Tools Description:\n{tool_desc_str}")

    # dict.fromkeys removes duplicates while preserving insertion order.
    tool_docs = {
        plan_name: "\n\n".join(dict.fromkeys(docs))
        for plan_name, docs in docs_by_plan.items()
    }
    tool_docs["all"] = "\n\n".join(dict.fromkeys(all_docs))
    return tool_docs
88
-
89
-
90
- def _check_plan_format(plan: Dict[str, Any]) -> bool:
91
- if not isinstance(plan, dict):
92
- return False
93
-
94
- for k in plan:
95
- if "thoughts" not in plan[k] or "instructions" not in plan[k]:
96
- return False
97
- if not isinstance(plan[k]["instructions"], list):
98
- return False
99
- return True
100
-
101
-
102
def write_plans(
    chat: List[Message], tool_desc: str, working_memory: str, model: LMM
) -> Dict[str, Any]:
    """Ask the model for candidate plans and return them once well-formed.

    The last chat message must come from the user; its content is wrapped in
    the ``PLAN`` prompt (on a deep copy, so the caller's chat is untouched).
    If the model output fails ``_check_plan_format`` the request is repeated,
    up to three retries.

    Args:
        chat: Conversation whose final message holds the user request.
        tool_desc: Tool descriptions inserted into the prompt.
        working_memory: Feedback text inserted into the prompt.
        model: Callable LMM invoked as ``model(chat, stream=False)``.

    Returns:
        The parsed plans.

    Raises:
        ValueError: If the last message is not from the user, or if no valid
            plan was produced after the retries.
    """
    chat = copy.deepcopy(chat)
    if chat[-1]["role"] != "user":
        raise ValueError("Last message in chat must be from user")

    user_request = chat[-1]["content"]
    context = USER_REQ.format(user_request=user_request)
    prompt = PLAN.format(
        context=context,
        tool_desc=tool_desc,
        feedback=working_memory,
    )
    chat[-1]["content"] = prompt
    plans = extract_json(model(chat, stream=False))  # type: ignore

    retries = 0
    while not _check_plan_format(plans) and retries < 3:
        _LOGGER.info("Invalid plan format. Retrying.")
        plans = extract_json(model(chat, stream=False))  # type: ignore
        retries += 1

    # Judge the final result itself rather than the retry counter: a valid
    # plan returned by the last retry must not be discarded.
    if not _check_plan_format(plans):
        raise ValueError("Failed to generate valid plans after 3 attempts.")
    return plans
127
-
128
-
129
def _exec_result_log(code: str, tool_output: Execution) -> Dict[str, Any]:
    """Build the progress entry reporting whether a plan-test run succeeded."""
    return {
        "type": "log",
        "log_content": (
            "Code execution succeeded"
            if tool_output.success
            else "Code execution failed"
        ),
        "code": DefaultImports.prepend_imports(code),
        "status": "completed" if tool_output.success else "failed",
    }


def write_and_exec_plan_tests(
    plans: Dict[str, Any],
    tool_info: str,
    media: List[str],
    model: LMM,
    log_progress: Callable[[Dict[str, Any]], None],
    code_interpreter: CodeInterpreter,
    verbosity: int = 0,
    max_retries: int = 3,
) -> Tuple[str, Execution]:
    """Generate code that tests the plans, run it, and retry on failure.

    The model is prompted with ``TEST_PLANS``; the returned code is executed
    through ``code_interpreter.exec_isolation`` with the default imports
    prepended. While the run fails or produces neither stdout nor stderr, the
    model is re-prompted with the previous code and the last 50 lines of its
    output, up to ``max_retries`` times.

    Args:
        plans: Candidate plans, formatted into the prompt via ``format_plans``.
        tool_info: Tool documentation inserted into the prompt.
        media: Media names inserted into the prompt.
        model: Callable LMM invoked as ``model(prompt, stream=False)``.
        log_progress: Callback that receives progress dicts.
        code_interpreter: Interpreter used to execute the generated code.
        verbosity: When exactly 2, code and execution output are logged.
        max_retries: Maximum number of regeneration attempts.

    Returns:
        The last generated code and the ``Execution`` of its run.
    """
    plan_str = format_plans(plans)
    prompt = TEST_PLANS.format(
        docstring=tool_info, plans=plan_str, previous_attempts="", media=media
    )

    code = extract_code(model(prompt, stream=False))  # type: ignore
    log_progress(
        {
            "type": "log",
            "log_content": "Executing code to test plans",
            "code": DefaultImports.prepend_imports(code),
            "status": "running",
        }
    )
    tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
    # Because of the way we trace function calls the trace information ends up in the
    # results. We don't want to show this info to the LLM so we don't include it in the
    # tool_output_str.
    tool_output_str = tool_output.text(include_results=False).strip()

    if verbosity == 2:
        print_code("Initial code and tests:", code)
        _LOGGER.info(f"Initial code execution result:\n{tool_output_str}")

    log_progress(_exec_result_log(code, tool_output))

    # Retry if the code fails or the run produced no output at all.
    attempt = 0
    while attempt < max_retries and (
        not tool_output.success
        or not (tool_output.logs.stdout or tool_output.logs.stderr)
    ):
        prompt = TEST_PLANS.format(
            docstring=tool_info,
            plans=plan_str,
            previous_attempts=PREVIOUS_FAILED.format(
                # Only the tail of the output is fed back to the model.
                code=code,
                error="\n".join(tool_output_str.splitlines()[-50:]),
            ),
            media=media,
        )
        # This entry carries the code of the attempt that is being replaced.
        log_progress(
            {
                "type": "log",
                "log_content": "Retrying code to test plans",
                "status": "running",
                "code": DefaultImports.prepend_imports(code),
            }
        )
        code = extract_code(model(prompt, stream=False))  # type: ignore
        tool_output = code_interpreter.exec_isolation(
            DefaultImports.prepend_imports(code)
        )
        log_progress(_exec_result_log(code, tool_output))
        tool_output_str = tool_output.text(include_results=False).strip()

        if verbosity == 2:
            print_code("Code and test after attempted fix:", code)
            _LOGGER.info(f"Code execution result after attempt {attempt + 1}")
            _LOGGER.info(f"{tool_output_str}")

        attempt += 1

    return code, tool_output
228
-
229
-
230
def write_plan_thoughts(
    chat: List[Message],
    plans: Dict[str, Any],
    tool_output_str: str,
    model: LMM,
    max_retries: int = 3,
) -> Dict[str, str]:
    """Ask the model to choose the best plan given the plan-test output.

    The last chat message's content is treated as the user request and is
    replaced by the ``PICK_PLAN`` prompt on a deep copy of ``chat``, so the
    caller's messages are left unmodified.

    Args:
        chat: Conversation whose final message holds the user request.
        plans: Candidate plans keyed by plan name.
        tool_output_str: Text output of the plan tests; only the first 20,000
            characters are sent to the model.
        model: Callable LMM invoked as ``model(chat, stream=False)``.
        max_retries: Maximum number of model calls while JSON extraction fails.

    Returns:
        A dict with a ``"best_plan"`` key that names an entry of ``plans`` and
        a ``"thoughts"`` key (empty string when the model gave none). When the
        model's answer is missing, is not a dict, or names an unknown plan,
        the first plan is used.
    """
    chat = copy.deepcopy(chat)
    user_req = chat[-1]["content"]
    context = USER_REQ.format(user_request=user_req)
    # because the tool picker model gets the image as well, we have to be careful with
    # how much text we send it, so we truncate the tool output to 20,000 characters
    prompt = PICK_PLAN.format(
        context=context,
        plans=format_plans(plans),
        tool_output=tool_output_str[:20_000],
    )
    chat[-1]["content"] = prompt

    plan_thoughts = None
    attempts = 0
    while plan_thoughts is None and attempts < max_retries:
        try:
            plan_thoughts = extract_json(model(chat, stream=False))  # type: ignore
        except JSONDecodeError as e:
            _LOGGER.exception(
                f"Error while extracting JSON during picking best plan {str(e)}"
            )
        attempts += 1

    if (
        not isinstance(plan_thoughts, dict)
        or "best_plan" not in plan_thoughts
        or plan_thoughts["best_plan"] not in plans
    ):
        _LOGGER.info(f"Failed to pick best plan. Using the first plan. {plan_thoughts}")
        plan_thoughts = {"best_plan": list(plans.keys())[0]}

    if "thoughts" not in plan_thoughts:
        plan_thoughts["thoughts"] = ""
    return plan_thoughts
271
-
272
-
273
def pick_plan(
    chat: List[Message],
    plans: Dict[str, Any],
    tool_info: str,
    model: LMM,
    code_interpreter: CodeInterpreter,
    media: List[str],
    log_progress: Callable[[Dict[str, Any]], None],
    verbosity: int = 0,
    max_retries: int = 3,
) -> Tuple[Dict[str, str], str, Execution]:
    """Test the candidate plans with generated code and select the best one.

    Runs ``write_and_exec_plan_tests`` and then ``write_plan_thoughts`` on a
    deep copy of ``chat``, reporting start and completion through
    ``log_progress``.

    Returns:
        The plan-selection dict, the test code, and the test ``Execution``.

    Raises:
        ValueError: If the last chat message is not from the user.
    """
    started_entry = {
        "type": "log",
        "log_content": "Generating code to pick the best plan",
        "status": "started",
    }
    log_progress(started_entry)

    chat = copy.deepcopy(chat)
    last_message = chat[-1]
    if last_message["role"] != "user":
        raise ValueError("Last chat message must be from the user.")

    code, tool_output = write_and_exec_plan_tests(
        plans=plans,
        tool_info=tool_info,
        media=media,
        model=model,
        log_progress=log_progress,
        code_interpreter=code_interpreter,
        verbosity=verbosity,
        max_retries=max_retries,
    )

    verbose = verbosity >= 1
    if verbose:
        print_code("Final code:", code)

    execution_text = tool_output.text(include_results=False).strip()
    plan_thoughts = write_plan_thoughts(
        chat=chat,
        plans=plans,
        tool_output_str=execution_text,
        model=model,
        max_retries=max_retries,
    )

    if verbose:
        _LOGGER.info(f"Best plan:\n{plan_thoughts}")

    completed_entry = {
        "type": "log",
        "log_content": "Picked best plan",
        "status": "completed",
        "payload": plans[plan_thoughts["best_plan"]],
    }
    log_progress(completed_entry)
    return plan_thoughts, code, tool_output
329
-
330
-
331
class VisionAgentPlanner(Agent):
    """Agent that writes candidate plans, optionally tests them, and picks one."""

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        """Configure the planner.

        Args:
            planner: LMM used to write and pick plans. Defaults to
                ``AnthropicLMM(temperature=0.0)``.
            tool_recommender: Similarity index used to retrieve tools.
                Defaults to ``Sim(T.get_tools_df(), sim_key="desc")``.
            verbosity: Any value above 0 sets this module's logger to INFO.
            report_progress_callback: Optional callback that receives every
                progress dict passed to ``log_progress``.
            code_interpreter: An interpreter instance, or a string handed to
                ``CodeInterpreterFactory.new_instance`` when a plan is
                generated, or None when the interpreter is supplied per call.
        """
        self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner
        self.verbosity = verbosity
        if self.verbosity > 0:
            _LOGGER.setLevel(logging.INFO)

        self.tool_recommender = (
            Sim(T.get_tools_df(), sim_key="desc")
            if tool_recommender is None
            else tool_recommender
        )
        self.report_progress_callback = report_progress_callback
        self.code_interpreter = code_interpreter

    def __call__(
        self, input: Union[str, List[Message]], media: Optional[Union[str, Path]] = None
    ) -> str:
        """Generate a plan for ``input`` and return the chosen plan as a string.

        Args:
            input: A user request, or a list of chat messages.
            media: Optional media item attached to the first message.

        Returns:
            ``str()`` of the plan selected by ``generate_plan``.
        """
        if isinstance(input, str):
            input = [{"role": "user", "content": input}]
        elif media is not None:
            # Attach media to a copy so the caller's messages are not modified.
            input = copy.deepcopy(input)
        # An empty chat is left for generate_plan to reject with its ValueError.
        if media is not None and input:
            input[0]["media"] = [media]
        planning_context = self.generate_plan(input)
        return str(planning_context.plans[planning_context.best_plan])

    def generate_plan(
        self,
        chat: List[Message],
        test_multi_plan: bool = True,
        custom_tool_names: Optional[List[str]] = None,
        code_interpreter: Optional[CodeInterpreter] = None,
    ) -> PlanContext:
        """Write plans for ``chat``, retrieve tools, and select the best plan.

        Args:
            chat: Conversation; messages may carry a ``"media"`` list whose
                names are appended to the message content.
            test_multi_plan: When True the plans are tested with generated
                code and the model picks the best one; otherwise the first
                plan is used.
            custom_tool_names: Tool names forwarded to
                ``get_tool_descriptions_by_names``.
            code_interpreter: Interpreter to use for this call; falls back to
                the one configured on the instance.

        Returns:
            A ``PlanContext`` describing the plans and the selected one.

        Raises:
            ValueError: If ``chat`` is empty, or if no code interpreter was
                passed here or configured on the instance.
        """
        if not chat:
            raise ValueError("Chat cannot be empty")

        if code_interpreter is None:
            if isinstance(self.code_interpreter, str):
                code_interpreter = CodeInterpreterFactory.new_instance(
                    self.code_interpreter
                )
            else:
                code_interpreter = self.code_interpreter
        if code_interpreter is None:
            # Entering `with None` would fail with an opaque context-manager
            # error; report the real cause instead.
            raise ValueError(
                "No code interpreter available: pass `code_interpreter` to "
                "generate_plan or to the planner's constructor."
            )
        code_interpreter = cast(CodeInterpreter, code_interpreter)
        with code_interpreter:
            chat = copy.deepcopy(chat)
            media_list = []
            for chat_i in chat:
                if "media" in chat_i:
                    for media in chat_i["media"]:
                        chat_i["content"] += f" Media name {media}"  # type: ignore
                        media_list.append(str(media))

            # Chat reduced to role/content/media, used for plan picking.
            int_chat = cast(
                List[Message],
                [
                    (
                        {
                            "role": c["role"],
                            "content": c["content"],
                            "media": c["media"],
                        }
                        if "media" in c
                        else {"role": c["role"], "content": c["content"]}
                    )
                    for c in chat
                ],
            )

            working_memory: List[Dict[str, str]] = []

            plans = write_plans(
                chat,
                get_tool_descriptions_by_names(
                    custom_tool_names, T.tools.FUNCTION_TOOLS, T.tools.UTIL_TOOLS  # type: ignore
                ),
                format_feedback(working_memory),
                self.planner,
            )
            if self.verbosity >= 1:
                for plan in plans:
                    plan_fixed = [
                        {"instructions": e} for e in plans[plan]["instructions"]
                    ]
                    _LOGGER.info(
                        f"\n{tabulate(tabular_data=plan_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
                    )

            tool_docs = retrieve_tools(
                plans,
                self.tool_recommender,
                self.log_progress,
                self.verbosity,
            )
            if test_multi_plan:
                plan_thoughts, code, tool_output = pick_plan(
                    int_chat,
                    plans,
                    tool_docs["all"],
                    self.planner,
                    code_interpreter,
                    media_list,
                    self.log_progress,
                    self.verbosity,
                )
                best_plan = plan_thoughts["best_plan"]
                plan_thoughts_str = plan_thoughts["thoughts"]
                tool_output_str = (
                    "```python\n"
                    + code
                    + "\n```\n"
                    + tool_output.text(include_results=False).strip()
                )
            else:
                best_plan = list(plans.keys())[0]
                tool_output_str = ""
                plan_thoughts_str = ""
                tool_output = None

            if best_plan in plans and best_plan in tool_docs:
                tool_doc = tool_docs[best_plan]
            else:
                if self.verbosity >= 1:
                    _LOGGER.warning(
                        f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
                    )
                k = list(plans.keys())[0]
                best_plan = k
                tool_doc = tool_docs[k]

        return PlanContext(
            plans=plans,
            best_plan=best_plan,
            plan_thoughts=plan_thoughts_str,
            tool_output=tool_output_str,
            test_results=tool_output,
            tool_doc=tool_doc,
        )

    def log_progress(self, log: Dict[str, Any]) -> None:
        """Forward ``log`` to the progress callback, if one was configured."""
        if self.report_progress_callback is not None:
            self.report_progress_callback(log)
481
-
482
-
483
class AnthropicVisionAgentPlanner(VisionAgentPlanner):
    # VisionAgentPlanner whose default planner is an AnthropicLMM.

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        # Fall back to a temperature-0 Anthropic model when none is supplied.
        if planner is None:
            planner = AnthropicLMM(temperature=0.0)
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )
499
-
500
-
501
class OpenAIVisionAgentPlanner(VisionAgentPlanner):
    # VisionAgentPlanner whose default planner is an OpenAILMM.

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        # Fall back to a temperature-0 OpenAI model when none is supplied.
        if planner is None:
            planner = OpenAILMM(temperature=0.0)
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )
517
-
518
-
519
class OllamaVisionAgentPlanner(VisionAgentPlanner):
    # VisionAgentPlanner defaulting to an Ollama model and an OllamaSim index.

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        # Defaults are built planner first, then recommender.
        if planner is None:
            planner = OllamaLMM(model_name="llama3.2-vision", temperature=0.0)
        if tool_recommender is None:
            tool_recommender = OllamaSim(T.get_tools_df(), sim_key="desc")
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )
543
-
544
-
545
class AzureVisionAgentPlanner(VisionAgentPlanner):
    # VisionAgentPlanner defaulting to an Azure OpenAI model and an AzureSim index.

    def __init__(
        self,
        planner: Optional[LMM] = None,
        tool_recommender: Optional[Sim] = None,
        verbosity: int = 0,
        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
    ) -> None:
        # Defaults are built planner first, then recommender.
        if planner is None:
            planner = AzureOpenAILMM(temperature=0.0)
        if tool_recommender is None:
            tool_recommender = AzureSim(T.get_tools_df(), sim_key="desc")
        super().__init__(
            planner=planner,
            tool_recommender=tool_recommender,
            verbosity=verbosity,
            report_progress_callback=report_progress_callback,
            code_interpreter=code_interpreter,
        )