vision-agent 0.2.66__py3-none-any.whl → 0.2.67__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/lmm/__init__.py +1 -1
- vision_agent/lmm/lmm.py +82 -0
- {vision_agent-0.2.66.dist-info → vision_agent-0.2.67.dist-info}/METADATA +1 -1
- {vision_agent-0.2.66.dist-info → vision_agent-0.2.67.dist-info}/RECORD +6 -6
- {vision_agent-0.2.66.dist-info → vision_agent-0.2.67.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.66.dist-info → vision_agent-0.2.67.dist-info}/WHEEL +0 -0
vision_agent/lmm/__init__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
from .lmm import LMM, AzureOpenAILMM, Message, OpenAILMM
|
1
|
+
from .lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
|
vision_agent/lmm/lmm.py
CHANGED
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
|
|
6
6
|
from pathlib import Path
|
7
7
|
from typing import Any, Callable, Dict, List, Optional, Union, cast
|
8
8
|
|
9
|
+
import requests
|
9
10
|
from openai import AzureOpenAI, OpenAI
|
10
11
|
|
11
12
|
import vision_agent.tools as T
|
@@ -267,3 +268,84 @@ class AzureOpenAILMM(OpenAILMM):
|
|
267
268
|
if json_mode:
|
268
269
|
kwargs["response_format"] = {"type": "json_object"}
|
269
270
|
self.kwargs = kwargs
|
271
|
+
|
272
|
+
|
273
|
+
class OllamaLMM(LMM):
    r"""An LMM class for communicating with a locally running Ollama server."""

    def __init__(
        self,
        model_name: str = "llava",
        base_url: Optional[str] = "http://localhost:11434/api",
        json_mode: bool = False,
        **kwargs: Any,
    ):
        """Initialize the Ollama LMM client.

        Parameters:
            model_name (str): Name of the Ollama model to use (e.g. "llava").
            base_url (Optional[str]): Base URL of the Ollama HTTP API.
            json_mode (bool): Whether JSON-formatted output is requested.
        """
        self.url = base_url
        self.model_name = model_name
        self.json_mode = json_mode
        # Streaming responses are not supported by this client.
        self.stream = False

    def __call__(
        self,
        input: Union[str, List[Message]],
    ) -> str:
        # A bare string is treated as a single prompt; a list of messages
        # is treated as a chat conversation.
        if isinstance(input, str):
            return self.generate(input)
        return self.chat(input)

    def chat(
        self,
        chat: List[Message],
    ) -> str:
        """Chat with the LMM model.

        Parameters:
            chat (List[Dict[str, str]]): A list of dictionaries containing the chat
                messages. The messages can be in the format:
                [{"role": "user", "content": "Hello!"}, ...]
                or if it contains media, it should be in the format:
                [{"role": "user", "content": "Hello!", "media": ["image1.jpg", ...]}, ...]

        Raises:
            ValueError: If the HTTP request does not return status 200.
        """
        fixed_chat = []
        for message in chat:
            if "media" in message:
                # Ollama expects base64-encoded images under the "images" key
                # rather than this library's "media" key.
                message["images"] = [encode_image(m) for m in message["media"]]
                del message["media"]
            fixed_chat.append(message)

        url = f"{self.url}/chat"
        data = {"model": self.model_name, "messages": fixed_chat, "stream": self.stream}
        json_data = json.dumps(data)
        response = requests.post(url, data=json_data)
        if response.status_code != 200:
            raise ValueError(f"Request failed with status code {response.status_code}")
        response = response.json()
        return response["message"]["content"]  # type: ignore

    def generate(
        self,
        prompt: str,
        media: Optional[List[Union[str, Path]]] = None,
    ) -> str:
        """Generate a completion for a single prompt, optionally with images.

        Parameters:
            prompt (str): The text prompt to send to the model.
            media (Optional[List[Union[str, Path]]]): Optional image file paths
                to base64-encode and include with the prompt.

        Raises:
            ValueError: If the HTTP request does not return status 200.
        """
        url = f"{self.url}/generate"
        data = {
            "model": self.model_name,
            "prompt": prompt,
            "images": [],
            "stream": self.stream,
        }

        # BUG FIX: images must be appended to `data` BEFORE serializing. The
        # original code called json.dumps(data) first and then mutated
        # data["images"], so the posted payload silently dropped all media.
        if media and len(media) > 0:
            for m in media:
                data["images"].append(encode_image(m))  # type: ignore

        json_data = json.dumps(data)
        response = requests.post(url, data=json_data)

        if response.status_code != 200:
            raise ValueError(f"Request failed with status code {response.status_code}")

        response = response.json()
        return response["response"]  # type: ignore
|
@@ -5,8 +5,8 @@ vision_agent/agent/vision_agent.py,sha256=HC63BP4jPiR4lJLEkKQ-zMV5C5JwjnuZvc7hVj
|
|
5
5
|
vision_agent/agent/vision_agent_prompts.py,sha256=jpGJjrxDrxZej5SSgsTEK1sSYttgkTiZqxZAU1jWfvk,8656
|
6
6
|
vision_agent/fonts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
7
7
|
vision_agent/fonts/default_font_ch_en.ttf,sha256=1YM0Z3XqLDjSNbF7ihQFSAIUdjF9m1rtHiNC_6QosTE,1594400
|
8
|
-
vision_agent/lmm/__init__.py,sha256=
|
9
|
-
vision_agent/lmm/lmm.py,sha256=
|
8
|
+
vision_agent/lmm/__init__.py,sha256=bw24xyQJHGzmph5e-bKCiTh9AX6tRFI2OUd0mofxjZI,68
|
9
|
+
vision_agent/lmm/lmm.py,sha256=V7jfU94HwA-SiQLY14USHrSGtagVKCNGjZhW5MyKipo,11547
|
10
10
|
vision_agent/tools/__init__.py,sha256=aE1O8cMeLDPO50Sc-CuAQ_Akh0viz7vBxDcVeZNqsA0,1604
|
11
11
|
vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
|
12
12
|
vision_agent/tools/tool_utils.py,sha256=wzRacbUpqk9hhfX_Y08rL8qP0XCN2w-8IZoYLi3Upn4,869
|
@@ -17,7 +17,7 @@ vision_agent/utils/image_utils.py,sha256=_cdiS5YrLzqkq_ZgFUO897m5M4_SCIThwUy4lOk
|
|
17
17
|
vision_agent/utils/sim.py,sha256=ci6Eta73dDgLP1Ajtknbgmf1g8aAvBHqlVQvBuLMKXQ,4427
|
18
18
|
vision_agent/utils/type_defs.py,sha256=BlI8ywWHAplC7kYWLvt4AOdnKpEW3qWEFm-GEOSkrFQ,1792
|
19
19
|
vision_agent/utils/video.py,sha256=rNmU9KEIkZB5-EztZNlUiKYN0mm_55A_2VGUM0QpqLA,8779
|
20
|
-
vision_agent-0.2.
|
21
|
-
vision_agent-0.2.
|
22
|
-
vision_agent-0.2.
|
23
|
-
vision_agent-0.2.
|
20
|
+
vision_agent-0.2.67.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
21
|
+
vision_agent-0.2.67.dist-info/METADATA,sha256=BZKENJv_iaNU-XDqc5z4Ygx7k2jR4_7BbIdGoJE3voA,8363
|
22
|
+
vision_agent-0.2.67.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
|
23
|
+
vision_agent-0.2.67.dist-info/RECORD,,
|
File without changes
|
File without changes
|