vision-agent 0.2.56__py3-none-any.whl → 0.2.58__py3-none-any.whl
This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- vision_agent/__init__.py +1 -2
- vision_agent/agent/agent.py +3 -1
- vision_agent/agent/vision_agent.py +110 -81
- vision_agent/agent/vision_agent_prompts.py +1 -1
- vision_agent/lmm/__init__.py +1 -1
- vision_agent/lmm/lmm.py +54 -116
- vision_agent/tools/__init__.py +2 -1
- vision_agent/tools/tools.py +3 -3
- {vision_agent-0.2.56.dist-info → vision_agent-0.2.58.dist-info}/METADATA +36 -7
- vision_agent-0.2.58.dist-info/RECORD +23 -0
- vision_agent/agent/agent_coder.py +0 -216
- vision_agent/agent/agent_coder_prompts.py +0 -135
- vision_agent/agent/data_interpreter.py +0 -475
- vision_agent/agent/data_interpreter_prompts.py +0 -186
- vision_agent/agent/easytool.py +0 -346
- vision_agent/agent/easytool_prompts.py +0 -89
- vision_agent/agent/easytool_v2.py +0 -781
- vision_agent/agent/easytool_v2_prompts.py +0 -152
- vision_agent/agent/reflexion.py +0 -299
- vision_agent/agent/reflexion_prompts.py +0 -100
- vision_agent/llm/__init__.py +0 -1
- vision_agent/llm/llm.py +0 -176
- vision_agent/tools/easytool_tools.py +0 -1242
- vision_agent-0.2.56.dist-info/RECORD +0 -36
- {vision_agent-0.2.56.dist-info → vision_agent-0.2.58.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.56.dist-info → vision_agent-0.2.58.dist-info}/WHEEL +0 -0
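The headline change in this release is the removal of the old vision_agent.llm module and the legacy agents (agent_coder, data_interpreter, easytool, reflexion) in favor of the single multimodal vision_agent.lmm module. A minimal migration sketch, assuming your 0.2.56 code imported a text-only LLM class from vision_agent.llm ("photo.jpg" is a placeholder filename):

    # Before (0.2.56): text-only LLM classes lived in vision_agent.llm.
    # from vision_agent.llm import OpenAILLM
    # After (0.2.58): one multimodal module covers both text and images.
    from vision_agent.lmm import LMM, OpenAILMM

    model: LMM = OpenAILMM(temperature=0.0)  # the default model is now "gpt-4o"
    print(model.generate("Describe this image", media=["photo.jpg"]))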
vision_agent/__init__.py
CHANGED
vision_agent/agent/agent.py
CHANGED
@@ -2,12 +2,14 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
+from vision_agent.lmm import Message
+
 
 class Agent(ABC):
     @abstractmethod
     def __call__(
         self,
-        input: Union[
+        input: Union[str, List[Message]],
         media: Optional[Union[str, Path]] = None,
     ) -> str:
         pass
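Message is the new chat type used throughout 0.2.58; it is defined in vision_agent/lmm/lmm.py below as Dict[str, TextOrImage]. In practice a message is a plain dict with an optional "media" key; a sketch with a placeholder filename:

    from vision_agent.lmm import Message

    msg: Message = {
        "role": "user",
        "content": "Count the cars in this image",
        "media": ["image.jpg"],  # placeholder path to a local image
    }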
vision_agent/agent/vision_agent.py
CHANGED
@@ -13,7 +13,6 @@ from rich.style import Style
 from rich.syntax import Syntax
 from tabulate import tabulate
 
-from vision_agent.llm.llm import AzureOpenAILLM
 import vision_agent.tools as T
 from vision_agent.agent import Agent
 from vision_agent.agent.vision_agent_prompts import (
@@ -25,8 +24,7 @@ from vision_agent.agent.vision_agent_prompts import (
     SIMPLE_TEST,
     USER_REQ,
 )
-from vision_agent.
-from vision_agent.lmm import LMM, OpenAILMM
+from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OpenAILMM
 from vision_agent.utils import CodeInterpreterFactory, Execution
 from vision_agent.utils.execute import CodeInterpreter
 from vision_agent.utils.image_utils import b64_to_pil
@@ -133,11 +131,10 @@ def extract_image(
 
 
 def write_plan(
-    chat: List[
+    chat: List[Message],
     tool_desc: str,
     working_memory: str,
-    model:
-    media: Optional[Sequence[Union[str, Path]]] = None,
+    model: LMM,
 ) -> List[Dict[str, str]]:
     chat = copy.deepcopy(chat)
     if chat[-1]["role"] != "user":
@@ -147,18 +144,58 @@ def write_plan(
     context = USER_REQ.format(user_request=user_request)
     prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
     chat[-1]["content"] = prompt
-
-
-
-
-
+    return extract_json(model.chat(chat))["plan"]  # type: ignore
+
+
+def write_code(
+    coder: LMM,
+    chat: List[Message],
+    tool_info: str,
+    feedback: str,
+) -> str:
+    chat = copy.deepcopy(chat)
+    if chat[-1]["role"] != "user":
+        raise ValueError("Last chat message must be from the user.")
+
+    user_request = chat[-1]["content"]
+    prompt = CODE.format(
+        docstring=tool_info,
+        question=user_request,
+        feedback=feedback,
+    )
+    chat[-1]["content"] = prompt
+    return extract_code(coder(chat))
+
+
+def write_test(
+    tester: LMM,
+    chat: List[Message],
+    tool_utils: str,
+    code: str,
+    feedback: str,
+    media: Optional[Sequence[Union[str, Path]]] = None,
+) -> str:
+    chat = copy.deepcopy(chat)
+    if chat[-1]["role"] != "user":
+        raise ValueError("Last chat message must be from the user.")
+
+    user_request = chat[-1]["content"]
+    prompt = SIMPLE_TEST.format(
+        docstring=tool_utils,
+        question=user_request,
+        code=code,
+        feedback=feedback,
+        media=media,
+    )
+    chat[-1]["content"] = prompt
+    return extract_code(tester(chat))
 
 
 def reflect(
-    chat: List[
+    chat: List[Message],
     plan: str,
     code: str,
-    model:
+    model: LMM,
 ) -> Dict[str, Union[str, bool]]:
     chat = copy.deepcopy(chat)
     if chat[-1]["role"] != "user":
@@ -168,22 +205,22 @@ def reflect(
     context = USER_REQ.format(user_request=user_request)
     prompt = REFLECT.format(context=context, plan=plan, code=code)
     chat[-1]["content"] = prompt
-    return extract_json(model
+    return extract_json(model(chat))
 
 
 def write_and_test_code(
-
+    chat: List[Message],
     tool_info: str,
     tool_utils: str,
     working_memory: List[Dict[str, str]],
-    coder:
-    tester:
-    debugger:
+    coder: LMM,
+    tester: LMM,
+    debugger: LMM,
     code_interpreter: CodeInterpreter,
     log_progress: Callable[[Dict[str, Any]], None],
     verbosity: int = 0,
     max_retries: int = 3,
-
+    media: Optional[Sequence[Union[str, Path]]] = None,
 ) -> Dict[str, Any]:
     log_progress(
         {
@@ -191,25 +228,9 @@ def write_and_test_code(
             "status": "started",
         }
     )
-    code =
-
-
-                docstring=tool_info,
-                question=task,
-                feedback=format_memory(working_memory),
-            )
-        )
-    )
-    test = extract_code(
-        tester(
-            SIMPLE_TEST.format(
-                docstring=tool_utils,
-                question=task,
-                code=code,
-                feedback=working_memory,
-                media=input_media,
-            )
-        )
+    code = write_code(coder, chat, tool_info, format_memory(working_memory))
+    test = write_test(
+        tester, chat, tool_utils, code, format_memory(working_memory), media
     )
 
     log_progress(
@@ -392,10 +413,10 @@ class VisionAgent(Agent):
 
     def __init__(
         self,
-        planner: Optional[
-        coder: Optional[
-        tester: Optional[
-        debugger: Optional[
+        planner: Optional[LMM] = None,
+        coder: Optional[LMM] = None,
+        tester: Optional[LMM] = None,
+        debugger: Optional[LMM] = None,
         tool_recommender: Optional[Sim] = None,
         verbosity: int = 0,
         report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
@@ -403,10 +424,10 @@ class VisionAgent(Agent):
         """Initialize the Vision Agent.
 
         Parameters:
-            planner (Optional[
-            coder (Optional[
-            tester (Optional[
-            debugger (Optional[
+            planner (Optional[LMM]): The planner model to use. Defaults to OpenAILMM.
+            coder (Optional[LMM]): The coder model to use. Defaults to OpenAILMM.
+            tester (Optional[LMM]): The tester model to use. Defaults to OpenAILMM.
+            debugger (Optional[LMM]): The debugger model to use. Defaults to OpenAILMM.
             tool_recommender (Optional[Sim]): The tool recommender model to use.
             verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
                 highest verbosity level which will output all intermediate debugging
@@ -418,12 +439,12 @@ class VisionAgent(Agent):
         """
 
         self.planner = (
-
+            OpenAILMM(temperature=0.0, json_mode=True) if planner is None else planner
         )
-        self.coder =
-        self.tester =
+        self.coder = OpenAILMM(temperature=0.0) if coder is None else coder
+        self.tester = OpenAILMM(temperature=0.0) if tester is None else tester
         self.debugger = (
-
+            OpenAILMM(temperature=0.0, json_mode=True) if debugger is None else debugger
         )
 
         self.tool_recommender = (
@@ -437,7 +458,7 @@ class VisionAgent(Agent):
 
     def __call__(
         self,
-        input: Union[
+        input: Union[str, List[Message]],
         media: Optional[Union[str, Path]] = None,
     ) -> str:
         """Chat with Vision Agent and return intermediate information regarding the task.
@@ -454,23 +475,26 @@ class VisionAgent(Agent):
 
         if isinstance(input, str):
             input = [{"role": "user", "content": input}]
-
+            if media is not None:
+                input[0]["media"] = [media]
+        results = self.chat_with_workflow(input)
         results.pop("working_memory")
         return results  # type: ignore
 
     def chat_with_workflow(
         self,
-        chat: List[
-        media: Optional[Union[str, Path]] = None,
+        chat: List[Message],
         self_reflection: bool = False,
         display_visualization: bool = False,
     ) -> Dict[str, Any]:
         """Chat with Vision Agent and return intermediate information regarding the task.
 
         Parameters:
-            chat (List[
-
-
+            chat (List[MediaChatItem]): A conversation
+                in the format of:
+                [{"role": "user", "content": "describe your task here..."}]
+                or if it contains media files, it should be in the format of:
+                [{"role": "user", "content": "describe your task here...", "media": ["image1.jpg", "image2.jpg"]}]
             self_reflection (bool): Whether to reflect on the task and debug the code.
             display_visualization (bool): If True, it opens a new window locally to
                 show the image(s) created by visualization code (if there is any).
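Taken together, the new entry points can be exercised as below; a usage sketch where the task and filename are placeholders, and the result keys shown are the ones visible elsewhere in this diff ("code", "success", "working_memory"):

    from vision_agent.agent import VisionAgent

    agent = VisionAgent()
    # String task plus an optional media path:
    answer = agent("Count the cars in this image", media="cars.jpg")
    # Or the full workflow, with per-message "media" lists:
    results = agent.chat_with_workflow(
        [{"role": "user", "content": "Count the cars", "media": ["cars.jpg"]}]
    )
    print(results["code"])  # the generated code for the task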
@@ -485,11 +509,19 @@ class VisionAgent(Agent):
 
         # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
         with CodeInterpreterFactory.new_instance() as code_interpreter:
-
-
-
-
-
+            chat = copy.deepcopy(chat)
+            media_list = []
+            for chat_i in chat:
+                if "media" in chat_i:
+                    for media in chat_i["media"]:
+                        media = code_interpreter.upload_file(media)
+                        chat_i["content"] += f" Media name {media}"  # type: ignore
+                        media_list.append(media)
+
+            int_chat = cast(
+                List[Message],
+                [{"role": c["role"], "content": c["content"]} for c in chat],
+            )
 
             code = ""
             test = ""
@@ -507,11 +539,10 @@ class VisionAgent(Agent):
                 }
             )
             plan_i = write_plan(
-
+                int_chat,
                 T.TOOL_DESCRIPTIONS,
                 format_memory(working_memory),
                 self.planner,
-                media=[media] if media else None,
             )
             plan_i_str = "\n-".join([e["instructions"] for e in plan_i])
 
@@ -534,9 +565,7 @@ class VisionAgent(Agent):
                 self.verbosity,
             )
             results = write_and_test_code(
-
-                    user_request=chat[0]["content"], subtasks=plan_i_str
-                ),
+                chat=int_chat,
                 tool_info=tool_info,
                 tool_utils=T.UTILITIES_DOCSTRING,
                 working_memory=working_memory,
@@ -546,7 +575,7 @@ class VisionAgent(Agent):
                 code_interpreter=code_interpreter,
                 log_progress=self.log_progress,
                 verbosity=self.verbosity,
-
+                media=media_list,
             )
             success = cast(bool, results["success"])
             code = cast(str, results["code"])
@@ -564,7 +593,7 @@ class VisionAgent(Agent):
                 }
             )
             reflection = reflect(
-
+                int_chat,
                 FULL_TASK.format(
                     user_request=chat[0]["content"], subtasks=plan_i_str
                 ),
@@ -634,10 +663,10 @@ class AzureVisionAgent(VisionAgent):
 
     def __init__(
        self,
-        planner: Optional[
-        coder: Optional[
-        tester: Optional[
-        debugger: Optional[
+        planner: Optional[LMM] = None,
+        coder: Optional[LMM] = None,
+        tester: Optional[LMM] = None,
+        debugger: Optional[LMM] = None,
         tool_recommender: Optional[Sim] = None,
         verbosity: int = 0,
         report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
@@ -645,10 +674,10 @@ class AzureVisionAgent(VisionAgent):
         """Initialize the Vision Agent.
 
         Parameters:
-            planner (Optional[
-            coder (Optional[
-            tester (Optional[
-            debugger (Optional[
+            planner (Optional[LMM]): The planner model to use. Defaults to OpenAILMM.
+            coder (Optional[LMM]): The coder model to use. Defaults to OpenAILMM.
+            tester (Optional[LMM]): The tester model to use. Defaults to OpenAILMM.
+            debugger (Optional[LMM]): The debugger model to use. Defaults to OpenAILMM.
             tool_recommender (Optional[Sim]): The tool recommender model to use.
             verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
                 highest verbosity level which will output all intermediate debugging
@@ -660,14 +689,14 @@ class AzureVisionAgent(VisionAgent):
         """
         super().__init__(
             planner=(
-
+                AzureOpenAILMM(temperature=0.0, json_mode=True)
                 if planner is None
                 else planner
             ),
-            coder=
-            tester=
+            coder=AzureOpenAILMM(temperature=0.0) if coder is None else coder,
+            tester=AzureOpenAILMM(temperature=0.0) if tester is None else tester,
             debugger=(
-
+                AzureOpenAILMM(temperature=0.0, json_mode=True)
                 if debugger is None
                 else debugger
             ),
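AzureVisionAgent wires AzureOpenAILMM into the same constructor. A construction sketch, assuming the Azure OpenAI credentials (for example AZURE_OPENAI_API_KEY and AZURE_OPENAI_ENDPOINT) are already configured in the environment, since AzureOpenAILMM reads them at init time:

    from vision_agent.agent import AzureVisionAgent
    from vision_agent.lmm import AzureOpenAILMM

    # Explicitly passing what the defaults above construct:
    agent = AzureVisionAgent(
        planner=AzureOpenAILMM(temperature=0.0, json_mode=True),
        coder=AzureOpenAILMM(temperature=0.0),
    )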
vision_agent/agent/vision_agent_prompts.py
CHANGED
@@ -171,7 +171,7 @@ This is the documentation for the functions you have access to. You may call any
 **Instructions**:
 1. Verify the fundamental functionality under normal conditions.
 2. Ensure each test case is well-documented with comments explaining the scenario it covers.
-3. Your test case MUST run only on the given
+3. Your test case MUST run only on the given images which are {media}
 4. Your test case MUST run only with the given values which is available in the question - {question}
 5. DO NOT use any non-existent or dummy image or video files that are not provided by the user's instructions.
 6. DO NOT mock any functions, you must test their functionality as is.
vision_agent/lmm/__init__.py
CHANGED
@@ -1 +1 @@
-from .lmm import LMM, AzureOpenAILMM,
+from .lmm import LMM, AzureOpenAILMM, Message, OpenAILMM
vision_agent/lmm/lmm.py
CHANGED
@@ -6,15 +6,13 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union, cast
 
-import requests
 from openai import AzureOpenAI, OpenAI
 
+import vision_agent.tools as T
 from vision_agent.tools.prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
 
 _LOGGER = logging.getLogger(__name__)
 
-_LLAVA_ENDPOINT = "https://svtswgdnleslqcsjvilau4p6u40jwrkn.lambda-url.us-east-2.on.aws"
-
 
 def encode_image(image: Union[str, Path]) -> str:
     with open(image, "rb") as f:
@@ -22,84 +20,38 @@ def encode_image(image: Union[str, Path]) -> str:
     return encoded_image
 
 
+TextOrImage = Union[str, List[Union[str, Path]]]
+Message = Dict[str, TextOrImage]
+
+
 class LMM(ABC):
     @abstractmethod
     def generate(
-        self, prompt: str,
+        self, prompt: str, media: Optional[List[Union[str, Path]]] = None
     ) -> str:
         pass
 
     @abstractmethod
     def chat(
         self,
-        chat: List[
-        images: Optional[List[Union[str, Path]]] = None,
+        chat: List[Message],
     ) -> str:
         pass
 
     @abstractmethod
     def __call__(
         self,
-        input: Union[str, List[
-        images: Optional[List[Union[str, Path]]] = None,
+        input: Union[str, List[Message]],
     ) -> str:
         pass
 
 
-class LLaVALMM(LMM):
-    r"""An LMM class for the LLaVA-1.6 34B model."""
-
-    def __init__(self, model_name: str):
-        self.model_name = model_name
-
-    def __call__(
-        self,
-        input: Union[str, List[Dict[str, str]]],
-        images: Optional[List[Union[str, Path]]] = None,
-    ) -> str:
-        if isinstance(input, str):
-            return self.generate(input, images)
-        return self.chat(input, images)
-
-    def chat(
-        self,
-        chat: List[Dict[str, str]],
-        images: Optional[List[Union[str, Path]]] = None,
-    ) -> str:
-        raise NotImplementedError("Chat not supported for LLaVA")
-
-    def generate(
-        self,
-        prompt: str,
-        images: Optional[List[Union[str, Path]]] = None,
-        temperature: float = 0.1,
-        max_new_tokens: int = 1500,
-    ) -> str:
-        data = {"prompt": prompt}
-        if images and len(images) > 0:
-            data["image"] = encode_image(images[0])
-        data["temperature"] = temperature  # type: ignore
-        data["max_new_tokens"] = max_new_tokens  # type: ignore
-        res = requests.post(
-            _LLAVA_ENDPOINT,
-            headers={"Content-Type": "application/json"},
-            json=data,
-        )
-        resp_json: Dict[str, Any] = res.json()
-        if (
-            "statusCode" in resp_json and resp_json["statusCode"] != 200
-        ) or "statusCode" not in resp_json:
-            _LOGGER.error(f"Request failed: {resp_json}")
-            raise ValueError(f"Request failed: {resp_json}")
-        return cast(str, resp_json["data"])
-
-
 class OpenAILMM(LMM):
     r"""An LMM class for the OpenAI GPT-4 Vision model."""
 
     def __init__(
         self,
-        model_name: str = "gpt-
+        model_name: str = "gpt-4o",
         api_key: Optional[str] = None,
         max_tokens: int = 1024,
         json_mode: bool = False,
@@ -120,44 +72,49 @@ class OpenAILMM(LMM):
 
     def __call__(
         self,
-        input: Union[str, List[
-        images: Optional[List[Union[str, Path]]] = None,
+        input: Union[str, List[Message]],
     ) -> str:
         if isinstance(input, str):
-            return self.generate(input
-        return self.chat(input
+            return self.generate(input)
+        return self.chat(input)
 
     def chat(
         self,
-        chat: List[
-        images: Optional[List[Union[str, Path]]] = None,
+        chat: List[Message],
     ) -> str:
+        """Chat with the LMM model.
+
+        Parameters:
+            chat (List[Dict[str, str]]): A list of dictionaries containing the chat
+                messages. The messages can be in the format:
+                [{"role": "user", "content": "Hello!"}, ...]
+                or if it contains media, it should be in the format:
+                [{"role": "user", "content": "Hello!", "media": ["image1.jpg", ...]}, ...]
+        """
         fixed_chat = []
         for c in chat:
             fixed_c = {"role": c["role"]}
             fixed_c["content"] = [{"type": "text", "text": c["content"]}]  # type: ignore
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                                "url": f"data:image/{extension};base64,{encoded_image}",
-                                "detail": "low",
+            if "media" in c:
+                for image in c["media"]:
+                    extension = Path(image).suffix
+                    if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
+                        extension = "jpg"
+                    elif extension.lower() == ".png":
+                        extension = "png"
+                    else:
+                        raise ValueError(f"Unsupported image extension: {extension}")
+                    encoded_image = encode_image(image)
+                    fixed_c["content"].append(  # type: ignore
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/{extension};base64,{encoded_image}",  # type: ignore
+                                "detail": "low",
+                            },
                         },
-
-
+                    )
+            fixed_chat.append(fixed_c)
 
         response = self.client.chat.completions.create(
             model=self.model_name, messages=fixed_chat, **self.kwargs  # type: ignore
@@ -168,7 +125,7 @@ class OpenAILMM(LMM):
     def generate(
         self,
         prompt: str,
-
+        media: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         message: List[Dict[str, Any]] = [
             {
@@ -178,10 +135,10 @@ class OpenAILMM(LMM):
                 ],
             }
         ]
-        if
-            for
-                extension = Path(
-                encoded_image = encode_image(
+        if media and len(media) > 0:
+            for m in media:
+                extension = Path(m).suffix
+                encoded_image = encode_image(m)
                 message[0]["content"].append(
                     {
                         "type": "image_url",
@@ -198,9 +155,7 @@ class OpenAILMM(LMM):
         return cast(str, response.choices[0].message.content)
 
     def generate_classifier(self, question: str) -> Callable:
-
-
-        api_doc = CLIP.description + "\n" + str(CLIP.usage)
+        api_doc = T.get_tool_documentation([T.clip])
         prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
             model=self.model_name,
@@ -220,12 +175,10 @@ class OpenAILMM(LMM):
             )
             raise ValueError("Failed to decode response")
 
-        return lambda x:
+        return lambda x: T.clip(x, params["prompt"])
 
     def generate_detector(self, question: str) -> Callable:
-
-
-        api_doc = GroundingDINO.description + "\n" + str(GroundingDINO.usage)
+        api_doc = T.get_tool_documentation([T.grounding_dino])
         prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
             model=self.model_name,
@@ -245,12 +198,10 @@ class OpenAILMM(LMM):
             )
             raise ValueError("Failed to decode response")
 
-        return lambda x:
+        return lambda x: T.grounding_dino(params["prompt"], x)
 
     def generate_segmentor(self, question: str) -> Callable:
-
-
-        api_doc = GroundingSAM.description + "\n" + str(GroundingSAM.usage)
+        api_doc = T.get_tool_documentation([T.grounding_sam])
         prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
             model=self.model_name,
@@ -270,17 +221,13 @@ class OpenAILMM(LMM):
             )
             raise ValueError("Failed to decode response")
 
-        return lambda x:
+        return lambda x: T.grounding_sam(params["prompt"], x)
 
     def generate_zero_shot_counter(self, question: str) -> Callable:
-
-
-        return lambda x: ZeroShotCounting()(**{"image": x})
+        return T.zero_shot_counting
 
     def generate_image_qa_tool(self, question: str) -> Callable:
-
-
-        return lambda x: ImageQuestionAnswering()(**{"prompt": question, "image": x})
+        return lambda x: T.image_question_answering(question, x)
 
 
 class AzureOpenAILMM(OpenAILMM):
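The generate_* factories now close over the functional tools in vision_agent.tools instead of the removed tool classes. A sketch of the detector path, using a stand-in image array since the returned callable forwards its argument straight to T.grounding_dino:

    import numpy as np

    from vision_agent.lmm import OpenAILMM

    lmm = OpenAILMM()
    detect = lmm.generate_detector("Can you detect the dogs in this image?")
    image = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a real image
    detections = detect(image)  # calls T.grounding_dino(params["prompt"], image)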
@@ -314,12 +261,3 @@ class AzureOpenAILMM(OpenAILMM):
         if json_mode:
             kwargs["response_format"] = {"type": "json_object"}
         self.kwargs = kwargs
-
-
-def get_lmm(name: str) -> LMM:
-    if name == "openai":
-        return OpenAILMM(name)
-    elif name == "llava":
-        return LLaVALMM(name)
-    else:
-        raise ValueError(f"Unknown LMM: {name}, current support openai, llava")