vision-agent 0.0.41__py3-none-any.whl → 0.0.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/agent/vision_agent.py +15 -10
- vision_agent/llm/llm.py +10 -7
- vision_agent/lmm/lmm.py +14 -3
- {vision_agent-0.0.41.dist-info → vision_agent-0.0.42.dist-info}/METADATA +1 -1
- {vision_agent-0.0.41.dist-info → vision_agent-0.0.42.dist-info}/RECORD +7 -7
- {vision_agent-0.0.41.dist-info → vision_agent-0.0.42.dist-info}/LICENSE +0 -0
- {vision_agent-0.0.41.dist-info → vision_agent-0.0.42.dist-info}/WHEEL +0 -0
@@ -256,7 +256,6 @@ def retrieval(
|
|
256
256
|
)
|
257
257
|
if tool_id is None:
|
258
258
|
return {}, ""
|
259
|
-
_LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})")
|
260
259
|
|
261
260
|
tool_instructions = tools[tool_id]
|
262
261
|
tool_usage = tool_instructions["usage"]
|
@@ -265,7 +264,6 @@ def retrieval(
|
|
265
264
|
parameters = choose_parameter(
|
266
265
|
model, question, tool_usage, previous_log, reflections
|
267
266
|
)
|
268
|
-
_LOGGER.info(f"\tParameters: {parameters} for {tool_name}")
|
269
267
|
if parameters is None:
|
270
268
|
return {}, ""
|
271
269
|
tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters}
|
@@ -290,7 +288,7 @@ def retrieval(
|
|
290
288
|
tool_results["call_results"] = call_results
|
291
289
|
|
292
290
|
call_results_str = str(call_results)
|
293
|
-
_LOGGER.info(f"\tCall Results: {call_results_str}")
|
291
|
+
# _LOGGER.info(f"\tCall Results: {call_results_str}")
|
294
292
|
return tool_results, call_results_str
|
295
293
|
|
296
294
|
|
@@ -344,7 +342,9 @@ def self_reflect(
|
|
344
342
|
|
345
343
|
def parse_reflect(reflect: str) -> bool:
|
346
344
|
# GPT-4V has a hard time following directions, so make the criteria less strict
|
347
|
-
return
|
345
|
+
return (
|
346
|
+
"finish" in reflect.lower() and len(reflect) < 100
|
347
|
+
) or "finish" in reflect.lower()[-10:]
|
348
348
|
|
349
349
|
|
350
350
|
def visualize_result(all_tool_results: List[Dict]) -> List[str]:
|
@@ -423,10 +423,16 @@ class VisionAgent(Agent):
|
|
423
423
|
verbose: bool = False,
|
424
424
|
):
|
425
425
|
self.task_model = (
|
426
|
-
OpenAILLM(json_mode=True)
|
426
|
+
OpenAILLM(json_mode=True, temperature=0.1)
|
427
|
+
if task_model is None
|
428
|
+
else task_model
|
429
|
+
)
|
430
|
+
self.answer_model = (
|
431
|
+
OpenAILLM(temperature=0.1) if answer_model is None else answer_model
|
432
|
+
)
|
433
|
+
self.reflect_model = (
|
434
|
+
OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
|
427
435
|
)
|
428
|
-
self.answer_model = OpenAILLM() if answer_model is None else answer_model
|
429
|
-
self.reflect_model = OpenAILMM() if reflect_model is None else reflect_model
|
430
436
|
self.max_retries = max_retries
|
431
437
|
|
432
438
|
self.tools = TOOLS
|
@@ -466,7 +472,6 @@ class VisionAgent(Agent):
|
|
466
472
|
for _ in range(self.max_retries):
|
467
473
|
task_list = create_tasks(self.task_model, question, self.tools, reflections)
|
468
474
|
|
469
|
-
_LOGGER.info(f"Task Dependency: {task_list}")
|
470
475
|
task_depend = {"Original Quesiton": question}
|
471
476
|
previous_log = ""
|
472
477
|
answers = []
|
@@ -477,7 +482,6 @@ class VisionAgent(Agent):
|
|
477
482
|
for task in task_list:
|
478
483
|
task_str = task["task"]
|
479
484
|
previous_log = str(task_depend)
|
480
|
-
_LOGGER.info(f"\tSubtask: {task_str}")
|
481
485
|
tool_results, call_results = retrieval(
|
482
486
|
self.task_model,
|
483
487
|
task_str,
|
@@ -492,6 +496,7 @@ class VisionAgent(Agent):
|
|
492
496
|
tool_results["answer"] = answer
|
493
497
|
all_tool_results.append(tool_results)
|
494
498
|
|
499
|
+
_LOGGER.info(f"\tCall Result: {call_results}")
|
495
500
|
_LOGGER.info(f"\tAnswer: {answer}")
|
496
501
|
answers.append({"task": task_str, "answer": answer})
|
497
502
|
task_depend[task["id"]]["answer"] = answer # type: ignore
|
@@ -510,7 +515,7 @@ class VisionAgent(Agent):
|
|
510
515
|
final_answer,
|
511
516
|
visualized_images[0] if len(visualized_images) > 0 else image,
|
512
517
|
)
|
513
|
-
_LOGGER.info(f"
|
518
|
+
_LOGGER.info(f"Reflection: {reflection}")
|
514
519
|
if parse_reflect(reflection):
|
515
520
|
break
|
516
521
|
else:
|
vision_agent/llm/llm.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import json
|
2
2
|
from abc import ABC, abstractmethod
|
3
|
-
from typing import Callable, Dict, List, Mapping, Union, cast
|
3
|
+
from typing import Any, Callable, Dict, List, Mapping, Union, cast
|
4
4
|
|
5
5
|
from openai import OpenAI
|
6
6
|
|
@@ -31,30 +31,33 @@ class OpenAILLM(LLM):
|
|
31
31
|
r"""An LLM class for any OpenAI LLM model."""
|
32
32
|
|
33
33
|
def __init__(
|
34
|
-
self,
|
34
|
+
self,
|
35
|
+
model_name: str = "gpt-4-turbo-preview",
|
36
|
+
json_mode: bool = False,
|
37
|
+
**kwargs: Any
|
35
38
|
):
|
36
39
|
self.model_name = model_name
|
37
40
|
self.client = OpenAI()
|
38
|
-
self.
|
41
|
+
self.kwargs = kwargs
|
42
|
+
if json_mode:
|
43
|
+
self.kwargs["response_format"] = {"type": "json_object"}
|
39
44
|
|
40
45
|
def generate(self, prompt: str) -> str:
|
41
|
-
kwargs = {"response_format": {"type": "json_object"}} if self.json_mode else {}
|
42
46
|
response = self.client.chat.completions.create(
|
43
47
|
model=self.model_name,
|
44
48
|
messages=[
|
45
49
|
{"role": "user", "content": prompt},
|
46
50
|
],
|
47
|
-
**kwargs,
|
51
|
+
**self.kwargs,
|
48
52
|
)
|
49
53
|
|
50
54
|
return cast(str, response.choices[0].message.content)
|
51
55
|
|
52
56
|
def chat(self, chat: List[Dict[str, str]]) -> str:
|
53
|
-
kwargs = {"response_format": {"type": "json_object"}} if self.json_mode else {}
|
54
57
|
response = self.client.chat.completions.create(
|
55
58
|
model=self.model_name,
|
56
59
|
messages=chat, # type: ignore
|
57
|
-
**kwargs,
|
60
|
+
**self.kwargs,
|
58
61
|
)
|
59
62
|
|
60
63
|
return cast(str, response.choices[0].message.content)
|
vision_agent/lmm/lmm.py
CHANGED
@@ -97,11 +97,15 @@ class OpenAILMM(LMM):
|
|
97
97
|
r"""An LMM class for the OpenAI GPT-4 Vision model."""
|
98
98
|
|
99
99
|
def __init__(
|
100
|
-
self,
|
100
|
+
self,
|
101
|
+
model_name: str = "gpt-4-vision-preview",
|
102
|
+
max_tokens: int = 1024,
|
103
|
+
**kwargs: Any,
|
101
104
|
):
|
102
105
|
self.model_name = model_name
|
103
106
|
self.max_tokens = max_tokens
|
104
107
|
self.client = OpenAI()
|
108
|
+
self.kwargs = kwargs
|
105
109
|
|
106
110
|
def __call__(
|
107
111
|
self,
|
@@ -123,6 +127,13 @@ class OpenAILMM(LMM):
|
|
123
127
|
|
124
128
|
if image:
|
125
129
|
extension = Path(image).suffix
|
130
|
+
if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
|
131
|
+
extension = "jpg"
|
132
|
+
elif extension.lower() == ".png":
|
133
|
+
extension = "png"
|
134
|
+
else:
|
135
|
+
raise ValueError(f"Unsupported image extension: {extension}")
|
136
|
+
|
126
137
|
encoded_image = encode_image(image)
|
127
138
|
fixed_chat[0]["content"].append( # type: ignore
|
128
139
|
{
|
@@ -135,7 +146,7 @@ class OpenAILMM(LMM):
|
|
135
146
|
)
|
136
147
|
|
137
148
|
response = self.client.chat.completions.create(
|
138
|
-
model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens # type: ignore
|
149
|
+
model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens, **self.kwargs # type: ignore
|
139
150
|
)
|
140
151
|
|
141
152
|
return cast(str, response.choices[0].message.content)
|
@@ -163,7 +174,7 @@ class OpenAILMM(LMM):
|
|
163
174
|
)
|
164
175
|
|
165
176
|
response = self.client.chat.completions.create(
|
166
|
-
model=self.model_name, messages=message, max_tokens=self.max_tokens # type: ignore
|
177
|
+
model=self.model_name, messages=message, max_tokens=self.max_tokens, **self.kwargs # type: ignore
|
167
178
|
)
|
168
179
|
return cast(str, response.choices[0].message.content)
|
169
180
|
|
@@ -5,7 +5,7 @@ vision_agent/agent/easytool.py,sha256=oMHnBg7YBtIPgqQUNcZgq7uMgpPThs99_UnO7ERkMV
|
|
5
5
|
vision_agent/agent/easytool_prompts.py,sha256=uNp12LOFRLr3i2zLhNuLuyFms2-s8es2t6P6h76QDow,4493
|
6
6
|
vision_agent/agent/reflexion.py,sha256=wzpptfALNZIh9Q5jgkK3imGL5LWjTW_n_Ypsvxdh07Q,10101
|
7
7
|
vision_agent/agent/reflexion_prompts.py,sha256=UPGkt_qgHBMUY0VPVoF-BqhR0d_6WPjjrhbYLBYOtnQ,9342
|
8
|
-
vision_agent/agent/vision_agent.py,sha256=
|
8
|
+
vision_agent/agent/vision_agent.py,sha256=P2melU6XQCCiiL1C_4QsxGUaWbwahuJA90eIcQJTR4U,17449
|
9
9
|
vision_agent/agent/vision_agent_prompts.py,sha256=otaDRsaHc7bqw_tgWTnu-eUcFeOzBFrn9sPU7_xr2VQ,6151
|
10
10
|
vision_agent/data/__init__.py,sha256=YU-5g3LbEQ6a4drz0RLGTagXMVU2Z4Xr3RlfWE-R0jU,46
|
11
11
|
vision_agent/data/data.py,sha256=pgtSGZdAnbQ8oGsuapLtFTMPajnCGDGekEXTnFuBwsY,5122
|
@@ -13,14 +13,14 @@ vision_agent/emb/__init__.py,sha256=YmCkGrJBtXb6X6Z3lnKiFoQYKXMgHMJp8JJyMLVvqcI,
|
|
13
13
|
vision_agent/emb/emb.py,sha256=la9lhEzk7jqUCjYYQ5oRgVNSnC9_EJBJIpE_B9c6PJo,1375
|
14
14
|
vision_agent/image_utils.py,sha256=XiOLpHAvlk55URw6iG7hl1OY71FVRA9_25b650amZXA,4420
|
15
15
|
vision_agent/llm/__init__.py,sha256=fBKsIjL4z08eA0QYx6wvhRe4Nkp2pJ4VrZK0-uUL5Ec,32
|
16
|
-
vision_agent/llm/llm.py,sha256=
|
16
|
+
vision_agent/llm/llm.py,sha256=l8ZVh6vCZOJBHfenfOoHwPySXEUQoNt_gbL14gkvu2g,3904
|
17
17
|
vision_agent/lmm/__init__.py,sha256=I8mbeNUajTfWVNqLsuFQVOaNBDlkIhYp9DFU8H4kB7g,51
|
18
|
-
vision_agent/lmm/lmm.py,sha256=
|
18
|
+
vision_agent/lmm/lmm.py,sha256=s_A3SKCoWm2biOt-gS9PXOsa9l-zrmR6mInLjAqam-A,8438
|
19
19
|
vision_agent/tools/__init__.py,sha256=AKN-T659HpwVearRnkCd6wWNoJ6K5kW9gAZwb8IQSLE,235
|
20
20
|
vision_agent/tools/prompts.py,sha256=9RBbyqlNlExsGKlJ89Jkph83DAEJ8PCVGaHoNbyN7TM,1416
|
21
21
|
vision_agent/tools/tools.py,sha256=aMTBxxaXQp33HwplOS8xrgfbsTJ8e1pwO6byR7HcTJI,23447
|
22
22
|
vision_agent/tools/video.py,sha256=40rscP8YvKN3lhZ4PDcOK4XbdFX2duCRpHY_krmBYKU,7476
|
23
|
-
vision_agent-0.0.
|
24
|
-
vision_agent-0.0.
|
25
|
-
vision_agent-0.0.
|
26
|
-
vision_agent-0.0.
|
23
|
+
vision_agent-0.0.42.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
24
|
+
vision_agent-0.0.42.dist-info/METADATA,sha256=r523uVvu-DsNoA-H-18O2JXF4J9G2nZ2cDSmjXUFq_M,5324
|
25
|
+
vision_agent-0.0.42.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
|
26
|
+
vision_agent-0.0.42.dist-info/RECORD,,
|
File without changes
|
File without changes
|