vision-agent 0.2.56__py3-none-any.whl → 0.2.57__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vision_agent/__init__.py CHANGED
@@ -1,3 +1,2 @@
 from .agent import Agent
-from .llm import LLM, OpenAILLM
-from .lmm import LMM, LLaVALMM, OpenAILMM, get_lmm
+from .lmm import LMM, OpenAILMM
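With this release the package's top-level exports drop the text-only LLM layer entirely. A minimal import sketch (assuming OPENAI_API_KEY is set for the default OpenAI client; the prompt is illustrative):

from vision_agent import Agent, LMM, OpenAILMM  # OpenAILLM, LLaVALMM and get_lmm are no longer importable

lmm = OpenAILMM(temperature=0.0)  # extra keyword arguments are forwarded to the chat completion call
print(lmm("Summarize what you can do."))  # a plain string input is routed through generate()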
@@ -2,12 +2,14 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
+from vision_agent.lmm import Message
+
 
 class Agent(ABC):
     @abstractmethod
     def __call__(
         self,
-        input: Union[List[Dict[str, str]], str],
+        input: Union[str, List[Message]],
         media: Optional[Union[str, Path]] = None,
     ) -> str:
         pass
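For downstream code the visible change is that agents now type chat input as List[Message] rather than List[Dict[str, str]]. A toy subclass sketch under that assumption (it implements only the abstract method shown in this hunk; the real base class may declare further abstract members not visible here):

from pathlib import Path
from typing import List, Optional, Union

from vision_agent.agent import Agent
from vision_agent.lmm import Message


class EchoAgent(Agent):
    # Illustrative only: echoes the latest user message instead of running a workflow.
    def __call__(
        self,
        input: Union[str, List[Message]],
        media: Optional[Union[str, Path]] = None,
    ) -> str:
        if isinstance(input, str):
            input = [{"role": "user", "content": input}]
        return str(input[-1]["content"])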
@@ -13,7 +13,6 @@ from rich.style import Style
 from rich.syntax import Syntax
 from tabulate import tabulate
 
-from vision_agent.llm.llm import AzureOpenAILLM
 import vision_agent.tools as T
 from vision_agent.agent import Agent
 from vision_agent.agent.vision_agent_prompts import (
@@ -25,8 +24,7 @@ from vision_agent.agent.vision_agent_prompts import (
     SIMPLE_TEST,
     USER_REQ,
 )
-from vision_agent.llm import LLM, OpenAILLM
-from vision_agent.lmm import LMM, OpenAILMM
+from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OpenAILMM
 from vision_agent.utils import CodeInterpreterFactory, Execution
 from vision_agent.utils.execute import CodeInterpreter
 from vision_agent.utils.image_utils import b64_to_pil
@@ -133,11 +131,10 @@ def extract_image(
 
 
 def write_plan(
-    chat: List[Dict[str, str]],
+    chat: List[Message],
     tool_desc: str,
    working_memory: str,
-    model: Union[LLM, LMM],
-    media: Optional[Sequence[Union[str, Path]]] = None,
+    model: LMM,
 ) -> List[Dict[str, str]]:
     chat = copy.deepcopy(chat)
     if chat[-1]["role"] != "user":
@@ -147,18 +144,58 @@ def write_plan(
     context = USER_REQ.format(user_request=user_request)
     prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
     chat[-1]["content"] = prompt
-    if isinstance(model, OpenAILMM):
-        media = extract_image(media)
-        return extract_json(model.chat(chat, images=media))["plan"]  # type: ignore
-    else:
-        return extract_json(model.chat(chat))["plan"]  # type: ignore
+    return extract_json(model.chat(chat))["plan"]  # type: ignore
+
+
+def write_code(
+    coder: LMM,
+    chat: List[Message],
+    tool_info: str,
+    feedback: str,
+) -> str:
+    chat = copy.deepcopy(chat)
+    if chat[-1]["role"] != "user":
+        raise ValueError("Last chat message must be from the user.")
+
+    user_request = chat[-1]["content"]
+    prompt = CODE.format(
+        docstring=tool_info,
+        question=user_request,
+        feedback=feedback,
+    )
+    chat[-1]["content"] = prompt
+    return extract_code(coder(chat))
+
+
+def write_test(
+    tester: LMM,
+    chat: List[Message],
+    tool_utils: str,
+    code: str,
+    feedback: str,
+    media: Optional[Sequence[Union[str, Path]]] = None,
+) -> str:
+    chat = copy.deepcopy(chat)
+    if chat[-1]["role"] != "user":
+        raise ValueError("Last chat message must be from the user.")
+
+    user_request = chat[-1]["content"]
+    prompt = SIMPLE_TEST.format(
+        docstring=tool_utils,
+        question=user_request,
+        code=code,
+        feedback=feedback,
+        media=media,
+    )
+    chat[-1]["content"] = prompt
+    return extract_code(tester(chat))
 
 
 def reflect(
-    chat: List[Dict[str, str]],
+    chat: List[Message],
     plan: str,
     code: str,
-    model: Union[LLM, LMM],
+    model: LMM,
 ) -> Dict[str, Union[str, bool]]:
     chat = copy.deepcopy(chat)
     if chat[-1]["role"] != "user":
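Planning, code generation, and test generation now all consume the same Message-based chat. A rough composition sketch under stated assumptions (the tool docstrings and image name are placeholders, and the helpers are invoked as they would be from inside this module):

from vision_agent.lmm import OpenAILMM

coder = OpenAILMM(temperature=0.0)
tester = OpenAILMM(temperature=0.0)
chat = [{"role": "user", "content": "Count the cars in the image. Media name cars.jpg"}]

code = write_code(coder, chat, tool_info="<tool docstrings>", feedback="")
test = write_test(tester, chat, tool_utils="<utility docstrings>", code=code, feedback="", media=["cars.jpg"])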
@@ -168,22 +205,22 @@ def reflect(
     context = USER_REQ.format(user_request=user_request)
     prompt = REFLECT.format(context=context, plan=plan, code=code)
     chat[-1]["content"] = prompt
-    return extract_json(model.chat(chat))
+    return extract_json(model(chat))
 
 
 def write_and_test_code(
-    task: str,
+    chat: List[Message],
     tool_info: str,
     tool_utils: str,
     working_memory: List[Dict[str, str]],
-    coder: LLM,
-    tester: LLM,
-    debugger: LLM,
+    coder: LMM,
+    tester: LMM,
+    debugger: LMM,
     code_interpreter: CodeInterpreter,
     log_progress: Callable[[Dict[str, Any]], None],
     verbosity: int = 0,
     max_retries: int = 3,
-    input_media: Optional[Union[str, Path]] = None,
+    media: Optional[Sequence[Union[str, Path]]] = None,
 ) -> Dict[str, Any]:
     log_progress(
         {
@@ -191,25 +228,9 @@ def write_and_test_code(
             "status": "started",
         }
     )
-    code = extract_code(
-        coder(
-            CODE.format(
-                docstring=tool_info,
-                question=task,
-                feedback=format_memory(working_memory),
-            )
-        )
-    )
-    test = extract_code(
-        tester(
-            SIMPLE_TEST.format(
-                docstring=tool_utils,
-                question=task,
-                code=code,
-                feedback=working_memory,
-                media=input_media,
-            )
-        )
+    code = write_code(coder, chat, tool_info, format_memory(working_memory))
+    test = write_test(
+        tester, chat, tool_utils, code, format_memory(working_memory), media
     )
 
     log_progress(
@@ -392,10 +413,10 @@ class VisionAgent(Agent):
 
     def __init__(
         self,
-        planner: Optional[Union[LLM, LMM]] = None,
-        coder: Optional[LLM] = None,
-        tester: Optional[LLM] = None,
-        debugger: Optional[LLM] = None,
+        planner: Optional[LMM] = None,
+        coder: Optional[LMM] = None,
+        tester: Optional[LMM] = None,
+        debugger: Optional[LMM] = None,
         tool_recommender: Optional[Sim] = None,
         verbosity: int = 0,
         report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
@@ -403,10 +424,10 @@ class VisionAgent(Agent):
         """Initialize the Vision Agent.
 
         Parameters:
-            planner (Optional[LLM]): The planner model to use. Defaults to OpenAILLM.
-            coder (Optional[LLM]): The coder model to use. Defaults to OpenAILLM.
-            tester (Optional[LLM]): The tester model to use. Defaults to OpenAILLM.
-            debugger (Optional[LLM]): The debugger model to
+            planner (Optional[LMM]): The planner model to use. Defaults to OpenAILMM.
+            coder (Optional[LMM]): The coder model to use. Defaults to OpenAILMM.
+            tester (Optional[LMM]): The tester model to use. Defaults to OpenAILMM.
+            debugger (Optional[LMM]): The debugger model to
             tool_recommender (Optional[Sim]): The tool recommender model to use.
             verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
                 highest verbosity level which will output all intermediate debugging
@@ -418,12 +439,12 @@ class VisionAgent(Agent):
         """
 
         self.planner = (
-            OpenAILLM(temperature=0.0, json_mode=True) if planner is None else planner
+            OpenAILMM(temperature=0.0, json_mode=True) if planner is None else planner
         )
-        self.coder = OpenAILLM(temperature=0.0) if coder is None else coder
-        self.tester = OpenAILLM(temperature=0.0) if tester is None else tester
+        self.coder = OpenAILMM(temperature=0.0) if coder is None else coder
+        self.tester = OpenAILMM(temperature=0.0) if tester is None else tester
         self.debugger = (
-            OpenAILLM(temperature=0.0, json_mode=True) if debugger is None else debugger
+            OpenAILMM(temperature=0.0, json_mode=True) if debugger is None else debugger
         )
 
         self.tool_recommender = (
@@ -437,7 +458,7 @@ class VisionAgent(Agent):
 
     def __call__(
         self,
-        input: Union[List[Dict[str, str]], str],
+        input: Union[str, List[Message]],
         media: Optional[Union[str, Path]] = None,
     ) -> str:
         """Chat with Vision Agent and return intermediate information regarding the task.
@@ -454,23 +475,26 @@ class VisionAgent(Agent):
 
         if isinstance(input, str):
             input = [{"role": "user", "content": input}]
-        results = self.chat_with_workflow(input, media)
+            if media is not None:
+                input[0]["media"] = [media]
+        results = self.chat_with_workflow(input)
         results.pop("working_memory")
         return results  # type: ignore
 
     def chat_with_workflow(
         self,
-        chat: List[Dict[str, str]],
-        media: Optional[Union[str, Path]] = None,
+        chat: List[Message],
         self_reflection: bool = False,
         display_visualization: bool = False,
     ) -> Dict[str, Any]:
         """Chat with Vision Agent and return intermediate information regarding the task.
 
         Parameters:
-            chat (List[Dict[str, str]]): A conversation in the format of
-                [{"role": "user", "content": "describe your task here..."}].
-            media (Optional[Union[str, Path]]): The media file to be used in the task.
+            chat (List[MediaChatItem]): A conversation
+                in the format of:
+                [{"role": "user", "content": "describe your task here..."}]
+                or if it contains media files, it should be in the format of:
+                [{"role": "user", "content": "describe your task here...", "media": ["image1.jpg", "image2.jpg"]}]
             self_reflection (bool): Whether to reflect on the task and debug the code.
             display_visualization (bool): If True, it opens a new window locally to
                 show the image(s) created by visualization code (if there is any).
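Media is therefore attached to the chat message itself instead of being passed as a separate argument. A usage sketch with an illustrative file name; the returned dictionary of intermediate artifacts is assumed (per the surrounding hunks) to expose the generated code under a "code" key:

from vision_agent.agent import VisionAgent

agent = VisionAgent(verbosity=1)
results = agent.chat_with_workflow(
    [
        {
            "role": "user",
            "content": "Count the number of workers wearing helmets",
            "media": ["workers.png"],
        }
    ]
)
print(results["code"])  # the dict also carries other intermediate artifacts such as the working memory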
@@ -485,11 +509,19 @@ class VisionAgent(Agent):
 
         # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
         with CodeInterpreterFactory.new_instance() as code_interpreter:
-            if media is not None:
-                media = code_interpreter.upload_file(media)
-                for chat_i in chat:
-                    if chat_i["role"] == "user":
-                        chat_i["content"] += f" Image name {media}"
+            chat = copy.deepcopy(chat)
+            media_list = []
+            for chat_i in chat:
+                if "media" in chat_i:
+                    for media in chat_i["media"]:
+                        media = code_interpreter.upload_file(media)
+                        chat_i["content"] += f" Media name {media}"  # type: ignore
+                        media_list.append(media)
+
+            int_chat = cast(
+                List[Message],
+                [{"role": c["role"], "content": c["content"]} for c in chat],
+            )
 
             code = ""
             test = ""
@@ -507,11 +539,10 @@ class VisionAgent(Agent):
                     }
                 )
                 plan_i = write_plan(
-                    chat,
+                    int_chat,
                     T.TOOL_DESCRIPTIONS,
                     format_memory(working_memory),
                     self.planner,
-                    media=[media] if media else None,
                 )
                 plan_i_str = "\n-".join([e["instructions"] for e in plan_i])
 
@@ -534,9 +565,7 @@ class VisionAgent(Agent):
                    self.verbosity,
                 )
                 results = write_and_test_code(
-                    task=FULL_TASK.format(
-                        user_request=chat[0]["content"], subtasks=plan_i_str
-                    ),
+                    chat=int_chat,
                     tool_info=tool_info,
                     tool_utils=T.UTILITIES_DOCSTRING,
                     working_memory=working_memory,
@@ -546,7 +575,7 @@ class VisionAgent(Agent):
                     code_interpreter=code_interpreter,
                     log_progress=self.log_progress,
                     verbosity=self.verbosity,
-                    input_media=media,
+                    media=media_list,
                 )
                 success = cast(bool, results["success"])
                 code = cast(str, results["code"])
@@ -564,7 +593,7 @@ class VisionAgent(Agent):
                         }
                     )
                     reflection = reflect(
-                        chat,
+                        int_chat,
                         FULL_TASK.format(
                             user_request=chat[0]["content"], subtasks=plan_i_str
                         ),
@@ -634,10 +663,10 @@ class AzureVisionAgent(VisionAgent):
 
     def __init__(
         self,
-        planner: Optional[Union[LLM, LMM]] = None,
-        coder: Optional[LLM] = None,
-        tester: Optional[LLM] = None,
-        debugger: Optional[LLM] = None,
+        planner: Optional[LMM] = None,
+        coder: Optional[LMM] = None,
+        tester: Optional[LMM] = None,
+        debugger: Optional[LMM] = None,
         tool_recommender: Optional[Sim] = None,
         verbosity: int = 0,
         report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
@@ -645,10 +674,10 @@ class AzureVisionAgent(VisionAgent):
         """Initialize the Vision Agent.
 
         Parameters:
-            planner (Optional[LLM]): The planner model to use. Defaults to OpenAILLM.
-            coder (Optional[LLM]): The coder model to use. Defaults to OpenAILLM.
-            tester (Optional[LLM]): The tester model to use. Defaults to OpenAILLM.
-            debugger (Optional[LLM]): The debugger model to
+            planner (Optional[LMM]): The planner model to use. Defaults to OpenAILMM.
+            coder (Optional[LMM]): The coder model to use. Defaults to OpenAILMM.
+            tester (Optional[LMM]): The tester model to use. Defaults to OpenAILMM.
+            debugger (Optional[LMM]): The debugger model to
             tool_recommender (Optional[Sim]): The tool recommender model to use.
             verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
                 highest verbosity level which will output all intermediate debugging
@@ -660,14 +689,14 @@ class AzureVisionAgent(VisionAgent):
         """
         super().__init__(
             planner=(
-                AzureOpenAILLM(temperature=0.0, json_mode=True)
+                AzureOpenAILMM(temperature=0.0, json_mode=True)
                 if planner is None
                 else planner
             ),
-            coder=AzureOpenAILLM(temperature=0.0) if coder is None else coder,
-            tester=AzureOpenAILLM(temperature=0.0) if tester is None else tester,
+            coder=AzureOpenAILMM(temperature=0.0) if coder is None else coder,
+            tester=AzureOpenAILMM(temperature=0.0) if tester is None else tester,
             debugger=(
-                AzureOpenAILLM(temperature=0.0, json_mode=True)
+                AzureOpenAILMM(temperature=0.0, json_mode=True)
                 if debugger is None
                 else debugger
             ),
@@ -171,7 +171,7 @@ This is the documentation for the functions you have access to. You may call any
 **Instructions**:
 1. Verify the fundamental functionality under normal conditions.
 2. Ensure each test case is well-documented with comments explaining the scenario it covers.
-3. Your test case MUST run only on the given image which is {media}
+3. Your test case MUST run only on the given images which are {media}
 4. Your test case MUST run only with the given values which is available in the question - {question}
 5. DO NOT use any non-existent or dummy image or video files that are not provided by the user's instructions.
 6. DO NOT mock any functions, you must test their functionality as is.
@@ -1 +1 @@
-from .lmm import LMM, AzureOpenAILMM, LLaVALMM, OpenAILMM, get_lmm
+from .lmm import LMM, AzureOpenAILMM, Message, OpenAILMM
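Message is only a type alias (see the lmm.py hunks below), so chat items stay plain dictionaries. A quick sketch of a well-formed item, with an illustrative file name:

from vision_agent.lmm import Message

msg: Message = {
    "role": "user",
    "content": "What objects are on the conveyor belt?",
    "media": ["frame_001.jpg"],  # optional; omit for text-only turns
}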
vision_agent/lmm/lmm.py CHANGED
@@ -6,15 +6,13 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union, cast
 
-import requests
 from openai import AzureOpenAI, OpenAI
 
+import vision_agent.tools as T
 from vision_agent.tools.prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
 
 _LOGGER = logging.getLogger(__name__)
 
-_LLAVA_ENDPOINT = "https://svtswgdnleslqcsjvilau4p6u40jwrkn.lambda-url.us-east-2.on.aws"
-
 
 def encode_image(image: Union[str, Path]) -> str:
     with open(image, "rb") as f:
@@ -22,84 +20,38 @@ def encode_image(image: Union[str, Path]) -> str:
     return encoded_image
 
 
+TextOrImage = Union[str, List[Union[str, Path]]]
+Message = Dict[str, TextOrImage]
+
+
 class LMM(ABC):
     @abstractmethod
     def generate(
-        self, prompt: str, images: Optional[List[Union[str, Path]]] = None
+        self, prompt: str, media: Optional[List[Union[str, Path]]] = None
     ) -> str:
         pass
 
     @abstractmethod
     def chat(
         self,
-        chat: List[Dict[str, str]],
-        images: Optional[List[Union[str, Path]]] = None,
+        chat: List[Message],
     ) -> str:
         pass
 
     @abstractmethod
     def __call__(
         self,
-        input: Union[str, List[Dict[str, str]]],
-        images: Optional[List[Union[str, Path]]] = None,
+        input: Union[str, List[Message]],
     ) -> str:
         pass
 
 
-class LLaVALMM(LMM):
-    r"""An LMM class for the LLaVA-1.6 34B model."""
-
-    def __init__(self, model_name: str):
-        self.model_name = model_name
-
-    def __call__(
-        self,
-        input: Union[str, List[Dict[str, str]]],
-        images: Optional[List[Union[str, Path]]] = None,
-    ) -> str:
-        if isinstance(input, str):
-            return self.generate(input, images)
-        return self.chat(input, images)
-
-    def chat(
-        self,
-        chat: List[Dict[str, str]],
-        images: Optional[List[Union[str, Path]]] = None,
-    ) -> str:
-        raise NotImplementedError("Chat not supported for LLaVA")
-
-    def generate(
-        self,
-        prompt: str,
-        images: Optional[List[Union[str, Path]]] = None,
-        temperature: float = 0.1,
-        max_new_tokens: int = 1500,
-    ) -> str:
-        data = {"prompt": prompt}
-        if images and len(images) > 0:
-            data["image"] = encode_image(images[0])
-        data["temperature"] = temperature  # type: ignore
-        data["max_new_tokens"] = max_new_tokens  # type: ignore
-        res = requests.post(
-            _LLAVA_ENDPOINT,
-            headers={"Content-Type": "application/json"},
-            json=data,
-        )
-        resp_json: Dict[str, Any] = res.json()
-        if (
-            "statusCode" in resp_json and resp_json["statusCode"] != 200
-        ) or "statusCode" not in resp_json:
-            _LOGGER.error(f"Request failed: {resp_json}")
-            raise ValueError(f"Request failed: {resp_json}")
-        return cast(str, resp_json["data"])
-
-
 class OpenAILMM(LMM):
     r"""An LMM class for the OpenAI GPT-4 Vision model."""
 
     def __init__(
         self,
-        model_name: str = "gpt-4-turbo",
+        model_name: str = "gpt-4o",
         api_key: Optional[str] = None,
         max_tokens: int = 1024,
         json_mode: bool = False,
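Since the abstract interface above no longer threads a separate images argument through every call, a stand-in model for offline tests only needs the three methods shown. A minimal sketch (the canned reply is obviously illustrative):

from pathlib import Path
from typing import List, Optional, Union

from vision_agent.lmm import LMM, Message


class CannedLMM(LMM):
    # Returns a fixed string; handy for exercising agent plumbing without API calls.
    def __init__(self, reply: str = "ok") -> None:
        self.reply = reply

    def generate(self, prompt: str, media: Optional[List[Union[str, Path]]] = None) -> str:
        return self.reply

    def chat(self, chat: List[Message]) -> str:
        return self.reply

    def __call__(self, input: Union[str, List[Message]]) -> str:
        return self.generate(input) if isinstance(input, str) else self.chat(input)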
@@ -120,44 +72,49 @@ class OpenAILMM(LMM):
 
     def __call__(
         self,
-        input: Union[str, List[Dict[str, str]]],
-        images: Optional[List[Union[str, Path]]] = None,
+        input: Union[str, List[Message]],
     ) -> str:
         if isinstance(input, str):
-            return self.generate(input, images)
-        return self.chat(input, images)
+            return self.generate(input)
+        return self.chat(input)
 
     def chat(
         self,
-        chat: List[Dict[str, str]],
-        images: Optional[List[Union[str, Path]]] = None,
+        chat: List[Message],
     ) -> str:
+        """Chat with the LMM model.
+
+        Parameters:
+            chat (List[Dict[str, str]]): A list of dictionaries containing the chat
+                messages. The messages can be in the format:
+                [{"role": "user", "content": "Hello!"}, ...]
+                or if it contains media, it should be in the format:
+                [{"role": "user", "content": "Hello!", "media": ["image1.jpg", ...]}, ...]
+        """
         fixed_chat = []
         for c in chat:
             fixed_c = {"role": c["role"]}
             fixed_c["content"] = [{"type": "text", "text": c["content"]}]  # type: ignore
-            fixed_chat.append(fixed_c)
-
-        if images and len(images) > 0:
-            for image in images:
-                extension = Path(image).suffix
-                if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
-                    extension = "jpg"
-                elif extension.lower() == ".png":
-                    extension = "png"
-                else:
-                    raise ValueError(f"Unsupported image extension: {extension}")
-
-                encoded_image = encode_image(image)
-                fixed_chat[0]["content"].append(  # type: ignore
-                    {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": f"data:image/{extension};base64,{encoded_image}",
-                            "detail": "low",
+            if "media" in c:
+                for image in c["media"]:
+                    extension = Path(image).suffix
+                    if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
+                        extension = "jpg"
+                    elif extension.lower() == ".png":
+                        extension = "png"
+                    else:
+                        raise ValueError(f"Unsupported image extension: {extension}")
+                    encoded_image = encode_image(image)
+                    fixed_c["content"].append(  # type: ignore
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/{extension};base64,{encoded_image}",  # type: ignore
+                                "detail": "low",
+                            },
                         },
-                    },
-                )
+                    )
+            fixed_chat.append(fixed_c)
 
         response = self.client.chat.completions.create(
             model=self.model_name, messages=fixed_chat, **self.kwargs  # type: ignore
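The upshot is that every chat turn can carry its own attachments, which chat() inlines as base64 image_url parts. A short sketch with an illustrative .png path (OPENAI_API_KEY assumed to be set):

from vision_agent.lmm import OpenAILMM

lmm = OpenAILMM(max_tokens=512)
answer = lmm.chat(
    [
        {
            "role": "user",
            "content": "How many screws are visible?",
            "media": ["bench.png"],  # only .jpg/.jpeg/.png extensions are accepted here
        }
    ]
)
print(answer)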
@@ -168,7 +125,7 @@ class OpenAILMM(LMM):
     def generate(
         self,
         prompt: str,
-        images: Optional[List[Union[str, Path]]] = None,
+        media: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         message: List[Dict[str, Any]] = [
             {
@@ -178,10 +135,10 @@ class OpenAILMM(LMM):
                 ],
             }
         ]
-        if images and len(images) > 0:
-            for image in images:
-                extension = Path(image).suffix
-                encoded_image = encode_image(image)
+        if media and len(media) > 0:
+            for m in media:
+                extension = Path(m).suffix
+                encoded_image = encode_image(m)
                 message[0]["content"].append(
                     {
                         "type": "image_url",
@@ -198,9 +155,7 @@ class OpenAILMM(LMM):
         return cast(str, response.choices[0].message.content)
 
     def generate_classifier(self, question: str) -> Callable:
-        from vision_agent.tools.easytool_tools import CLIP
-
-        api_doc = CLIP.description + "\n" + str(CLIP.usage)
+        api_doc = T.get_tool_documentation([T.clip])
         prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
             model=self.model_name,
@@ -220,12 +175,10 @@ class OpenAILMM(LMM):
             )
             raise ValueError("Failed to decode response")
 
-        return lambda x: CLIP()(**{"prompt": params["prompt"], "image": x})
+        return lambda x: T.clip(x, params["prompt"])
 
     def generate_detector(self, question: str) -> Callable:
-        from vision_agent.tools.easytool_tools import GroundingDINO
-
-        api_doc = GroundingDINO.description + "\n" + str(GroundingDINO.usage)
+        api_doc = T.get_tool_documentation([T.grounding_dino])
         prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
             model=self.model_name,
@@ -245,12 +198,10 @@ class OpenAILMM(LMM):
             )
             raise ValueError("Failed to decode response")
 
-        return lambda x: GroundingDINO()(**{"prompt": params["prompt"], "image": x})
+        return lambda x: T.grounding_dino(params["prompt"], x)
 
     def generate_segmentor(self, question: str) -> Callable:
-        from vision_agent.tools.easytool_tools import GroundingSAM
-
-        api_doc = GroundingSAM.description + "\n" + str(GroundingSAM.usage)
+        api_doc = T.get_tool_documentation([T.grounding_sam])
        prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
             model=self.model_name,
@@ -270,17 +221,13 @@ class OpenAILMM(LMM):
             )
             raise ValueError("Failed to decode response")
 
-        return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})
+        return lambda x: T.grounding_sam(params["prompt"], x)
 
     def generate_zero_shot_counter(self, question: str) -> Callable:
-        from vision_agent.tools.easytool_tools import ZeroShotCounting
-
-        return lambda x: ZeroShotCounting()(**{"image": x})
+        return T.zero_shot_counting
 
     def generate_image_qa_tool(self, question: str) -> Callable:
-        from vision_agent.tools.easytool_tools import ImageQuestionAnswering
-
-        return lambda x: ImageQuestionAnswering()(**{"prompt": question, "image": x})
+        return lambda x: T.image_question_answering(question, x)
 
 
 class AzureOpenAILMM(OpenAILMM):
@@ -314,12 +261,3 @@ class AzureOpenAILMM(OpenAILMM):
         if json_mode:
             kwargs["response_format"] = {"type": "json_object"}
         self.kwargs = kwargs
-
-
-def get_lmm(name: str) -> LMM:
-    if name == "openai":
-        return OpenAILMM(name)
-    elif name == "llava":
-        return LLaVALMM(name)
-    else:
-        raise ValueError(f"Unknown LMM: {name}, current support openai, llava")