ag2 0.4.1__py3-none-any.whl → 0.4.2b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ag2 might be problematic. Click here for more details.
- ag2-0.4.2b1.dist-info/METADATA +19 -0
- ag2-0.4.2b1.dist-info/RECORD +6 -0
- ag2-0.4.2b1.dist-info/top_level.txt +1 -0
- ag2-0.4.1.dist-info/METADATA +0 -500
- ag2-0.4.1.dist-info/RECORD +0 -158
- ag2-0.4.1.dist-info/top_level.txt +0 -1
- autogen/__init__.py +0 -17
- autogen/_pydantic.py +0 -116
- autogen/agentchat/__init__.py +0 -42
- autogen/agentchat/agent.py +0 -142
- autogen/agentchat/assistant_agent.py +0 -85
- autogen/agentchat/chat.py +0 -306
- autogen/agentchat/contrib/__init__.py +0 -0
- autogen/agentchat/contrib/agent_builder.py +0 -788
- autogen/agentchat/contrib/agent_eval/agent_eval.py +0 -107
- autogen/agentchat/contrib/agent_eval/criterion.py +0 -47
- autogen/agentchat/contrib/agent_eval/critic_agent.py +0 -47
- autogen/agentchat/contrib/agent_eval/quantifier_agent.py +0 -42
- autogen/agentchat/contrib/agent_eval/subcritic_agent.py +0 -48
- autogen/agentchat/contrib/agent_eval/task.py +0 -43
- autogen/agentchat/contrib/agent_optimizer.py +0 -450
- autogen/agentchat/contrib/capabilities/__init__.py +0 -0
- autogen/agentchat/contrib/capabilities/agent_capability.py +0 -21
- autogen/agentchat/contrib/capabilities/generate_images.py +0 -297
- autogen/agentchat/contrib/capabilities/teachability.py +0 -406
- autogen/agentchat/contrib/capabilities/text_compressors.py +0 -72
- autogen/agentchat/contrib/capabilities/transform_messages.py +0 -92
- autogen/agentchat/contrib/capabilities/transforms.py +0 -565
- autogen/agentchat/contrib/capabilities/transforms_util.py +0 -120
- autogen/agentchat/contrib/capabilities/vision_capability.py +0 -217
- autogen/agentchat/contrib/captainagent/tools/__init__.py +0 -0
- autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_correlation.py +0 -41
- autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_skewness_and_kurtosis.py +0 -29
- autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_iqr.py +0 -29
- autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_zscore.py +0 -29
- autogen/agentchat/contrib/captainagent/tools/data_analysis/explore_csv.py +0 -22
- autogen/agentchat/contrib/captainagent/tools/data_analysis/shapiro_wilk_test.py +0 -31
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_download.py +0 -26
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_search.py +0 -55
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_image.py +0 -54
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_text.py +0 -39
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_wikipedia_text.py +0 -22
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_youtube_caption.py +0 -35
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/image_qa.py +0 -61
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/optical_character_recognition.py +0 -62
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/perform_web_search.py +0 -48
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/scrape_wikipedia_tables.py +0 -34
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/transcribe_audio_file.py +0 -22
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/youtube_download.py +0 -36
- autogen/agentchat/contrib/captainagent/tools/math/calculate_circle_area_from_diameter.py +0 -22
- autogen/agentchat/contrib/captainagent/tools/math/calculate_day_of_the_week.py +0 -19
- autogen/agentchat/contrib/captainagent/tools/math/calculate_fraction_sum.py +0 -29
- autogen/agentchat/contrib/captainagent/tools/math/calculate_matrix_power.py +0 -32
- autogen/agentchat/contrib/captainagent/tools/math/calculate_reflected_point.py +0 -17
- autogen/agentchat/contrib/captainagent/tools/math/complex_numbers_product.py +0 -26
- autogen/agentchat/contrib/captainagent/tools/math/compute_currency_conversion.py +0 -24
- autogen/agentchat/contrib/captainagent/tools/math/count_distinct_permutations.py +0 -28
- autogen/agentchat/contrib/captainagent/tools/math/evaluate_expression.py +0 -29
- autogen/agentchat/contrib/captainagent/tools/math/find_continuity_point.py +0 -35
- autogen/agentchat/contrib/captainagent/tools/math/fraction_to_mixed_numbers.py +0 -40
- autogen/agentchat/contrib/captainagent/tools/math/modular_inverse_sum.py +0 -23
- autogen/agentchat/contrib/captainagent/tools/math/simplify_mixed_numbers.py +0 -37
- autogen/agentchat/contrib/captainagent/tools/math/sum_of_digit_factorials.py +0 -16
- autogen/agentchat/contrib/captainagent/tools/math/sum_of_primes_below.py +0 -16
- autogen/agentchat/contrib/captainagent/tools/requirements.txt +0 -10
- autogen/agentchat/contrib/captainagent/tools/tool_description.tsv +0 -34
- autogen/agentchat/contrib/captainagent.py +0 -490
- autogen/agentchat/contrib/gpt_assistant_agent.py +0 -545
- autogen/agentchat/contrib/graph_rag/__init__.py +0 -0
- autogen/agentchat/contrib/graph_rag/document.py +0 -30
- autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +0 -111
- autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py +0 -81
- autogen/agentchat/contrib/graph_rag/graph_query_engine.py +0 -56
- autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +0 -64
- autogen/agentchat/contrib/img_utils.py +0 -390
- autogen/agentchat/contrib/llamaindex_conversable_agent.py +0 -123
- autogen/agentchat/contrib/llava_agent.py +0 -176
- autogen/agentchat/contrib/math_user_proxy_agent.py +0 -471
- autogen/agentchat/contrib/multimodal_conversable_agent.py +0 -128
- autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +0 -325
- autogen/agentchat/contrib/retrieve_assistant_agent.py +0 -56
- autogen/agentchat/contrib/retrieve_user_proxy_agent.py +0 -705
- autogen/agentchat/contrib/society_of_mind_agent.py +0 -203
- autogen/agentchat/contrib/swarm_agent.py +0 -463
- autogen/agentchat/contrib/text_analyzer_agent.py +0 -76
- autogen/agentchat/contrib/tool_retriever.py +0 -120
- autogen/agentchat/contrib/vectordb/__init__.py +0 -0
- autogen/agentchat/contrib/vectordb/base.py +0 -243
- autogen/agentchat/contrib/vectordb/chromadb.py +0 -326
- autogen/agentchat/contrib/vectordb/mongodb.py +0 -559
- autogen/agentchat/contrib/vectordb/pgvectordb.py +0 -958
- autogen/agentchat/contrib/vectordb/qdrant.py +0 -334
- autogen/agentchat/contrib/vectordb/utils.py +0 -126
- autogen/agentchat/contrib/web_surfer.py +0 -305
- autogen/agentchat/conversable_agent.py +0 -2908
- autogen/agentchat/groupchat.py +0 -1668
- autogen/agentchat/user_proxy_agent.py +0 -109
- autogen/agentchat/utils.py +0 -207
- autogen/browser_utils.py +0 -291
- autogen/cache/__init__.py +0 -10
- autogen/cache/abstract_cache_base.py +0 -78
- autogen/cache/cache.py +0 -182
- autogen/cache/cache_factory.py +0 -85
- autogen/cache/cosmos_db_cache.py +0 -150
- autogen/cache/disk_cache.py +0 -109
- autogen/cache/in_memory_cache.py +0 -61
- autogen/cache/redis_cache.py +0 -128
- autogen/code_utils.py +0 -745
- autogen/coding/__init__.py +0 -22
- autogen/coding/base.py +0 -113
- autogen/coding/docker_commandline_code_executor.py +0 -262
- autogen/coding/factory.py +0 -45
- autogen/coding/func_with_reqs.py +0 -203
- autogen/coding/jupyter/__init__.py +0 -22
- autogen/coding/jupyter/base.py +0 -32
- autogen/coding/jupyter/docker_jupyter_server.py +0 -164
- autogen/coding/jupyter/embedded_ipython_code_executor.py +0 -182
- autogen/coding/jupyter/jupyter_client.py +0 -224
- autogen/coding/jupyter/jupyter_code_executor.py +0 -161
- autogen/coding/jupyter/local_jupyter_server.py +0 -168
- autogen/coding/local_commandline_code_executor.py +0 -410
- autogen/coding/markdown_code_extractor.py +0 -44
- autogen/coding/utils.py +0 -57
- autogen/exception_utils.py +0 -46
- autogen/extensions/__init__.py +0 -0
- autogen/formatting_utils.py +0 -76
- autogen/function_utils.py +0 -362
- autogen/graph_utils.py +0 -148
- autogen/io/__init__.py +0 -15
- autogen/io/base.py +0 -105
- autogen/io/console.py +0 -43
- autogen/io/websockets.py +0 -213
- autogen/logger/__init__.py +0 -11
- autogen/logger/base_logger.py +0 -140
- autogen/logger/file_logger.py +0 -287
- autogen/logger/logger_factory.py +0 -29
- autogen/logger/logger_utils.py +0 -42
- autogen/logger/sqlite_logger.py +0 -459
- autogen/math_utils.py +0 -356
- autogen/oai/__init__.py +0 -33
- autogen/oai/anthropic.py +0 -428
- autogen/oai/bedrock.py +0 -606
- autogen/oai/cerebras.py +0 -270
- autogen/oai/client.py +0 -1148
- autogen/oai/client_utils.py +0 -167
- autogen/oai/cohere.py +0 -453
- autogen/oai/completion.py +0 -1216
- autogen/oai/gemini.py +0 -469
- autogen/oai/groq.py +0 -281
- autogen/oai/mistral.py +0 -279
- autogen/oai/ollama.py +0 -582
- autogen/oai/openai_utils.py +0 -811
- autogen/oai/together.py +0 -343
- autogen/retrieve_utils.py +0 -487
- autogen/runtime_logging.py +0 -163
- autogen/token_count_utils.py +0 -259
- autogen/types.py +0 -20
- autogen/version.py +0 -7
- {ag2-0.4.1.dist-info → ag2-0.4.2b1.dist-info}/LICENSE +0 -0
- {ag2-0.4.1.dist-info → ag2-0.4.2b1.dist-info}/NOTICE.md +0 -0
- {ag2-0.4.1.dist-info → ag2-0.4.2b1.dist-info}/WHEEL +0 -0
|
@@ -1,123 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
|
|
2
|
-
#
|
|
3
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
-
#
|
|
5
|
-
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
|
6
|
-
# SPDX-License-Identifier: MIT
|
|
7
|
-
from typing import Dict, List, Optional, Tuple, Union
|
|
8
|
-
|
|
9
|
-
from autogen import OpenAIWrapper
|
|
10
|
-
from autogen.agentchat import Agent, ConversableAgent
|
|
11
|
-
from autogen.agentchat.contrib.vectordb.utils import get_logger
|
|
12
|
-
|
|
13
|
-
logger = get_logger(__name__)

try:
    # llama-index and pydantic are optional dependencies of this contrib module.
    from llama_index.core.agent.runner.base import AgentRunner
    from llama_index.core.base.llms.types import ChatMessage
    from llama_index.core.chat_engine.types import AgentChatResponse
    from pydantic import BaseModel

    class Config:
        """Pydantic configuration allowing fields of arbitrary (non-pydantic) types."""

        arbitrary_types_allowed = True

    # Workaround for PydanticSchemaGenerationError.
    # NOTE(review): this assigns ``model_config`` on the global ``BaseModel`` class, so it
    # affects every pydantic model created afterwards in the process; pydantic v2 documents
    # ``model_config`` as a ``ConfigDict`` rather than a class — confirm this is still needed.
    BaseModel.model_config = Config
except ImportError:
    # Log an actionable hint, then propagate the original ImportError unchanged.
    logger.critical("Failed to import llama-index. Try running 'pip install llama-index'")
    raise
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
class LLamaIndexConversableAgent(ConversableAgent):
    """A ConversableAgent whose LLM replies are produced by a llama-index ``AgentRunner``.

    The default ``generate_oai_reply`` / ``a_generate_oai_reply`` reply functions inherited from
    ``ConversableAgent`` are replaced so that replies come from ``chat`` / ``achat`` on the
    wrapped llama-index agent.
    """

    def __init__(
        self,
        name: str,
        llama_index_agent: AgentRunner,
        description: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            name (str): agent name.
            llama_index_agent (AgentRunner): llama index agent.
                Please override this attribute if you want to reprogram the agent.
            description (str): a short description of the agent. This description is used by other agents
                (e.g. the GroupChatManager) to decide when to call upon this agent. Must be a non-blank string.
            **kwargs (dict): Please refer to other kwargs in
                [ConversableAgent](../conversable_agent#__init__).

        Raises:
            ValueError: if `llama_index_agent` is None, or `description` is None, empty, or whitespace only.
        """
        if llama_index_agent is None:
            raise ValueError("llama_index_agent must be provided")

        # `"".isspace()` is False, so checking the stripped text is needed to also reject "".
        if description is None or not description.strip():
            raise ValueError("description must be provided")

        super().__init__(
            name,
            description=description,
            **kwargs,
        )

        self._llama_index_agent = llama_index_agent

        # Route both the sync and async LLM reply paths through the llama-index agent.
        self.replace_reply_func(ConversableAgent.generate_oai_reply, LLamaIndexConversableAgent._generate_oai_reply)
        self.replace_reply_func(ConversableAgent.a_generate_oai_reply, LLamaIndexConversableAgent._a_generate_oai_reply)

    def _generate_oai_reply(
        self,
        messages: Optional[List[Dict]] = None,
        sender: Optional[Agent] = None,
        config: Optional[OpenAIWrapper] = None,
    ) -> Tuple[bool, Union[str, Dict, None]]:
        """Generate a reply by calling `chat` on the wrapped llama-index agent.

        Args:
            messages: the conversation; when empty or None, the stored history with `sender` is used.
            sender: the agent whose conversation is being answered.
            config: unused; kept to match the reply-function signature.

        Returns:
            `(True, response_text)`: the reply is always treated as final.
        """
        user_message, history = self._extract_message_and_history(messages=messages, sender=sender)

        chat_response: AgentChatResponse = self._llama_index_agent.chat(message=user_message, chat_history=history)

        return (True, chat_response.response)

    async def _a_generate_oai_reply(
        self,
        messages: Optional[List[Dict]] = None,
        sender: Optional[Agent] = None,
        config: Optional[OpenAIWrapper] = None,
    ) -> Tuple[bool, Union[str, Dict, None]]:
        """Async counterpart of `_generate_oai_reply`, calling `achat` on the llama-index agent.

        Args:
            messages: the conversation; when empty or None, the stored history with `sender` is used.
            sender: the agent whose conversation is being answered.
            config: unused; kept to match the reply-function signature.

        Returns:
            `(True, response_text)`: the reply is always treated as final.
        """
        user_message, history = self._extract_message_and_history(messages=messages, sender=sender)

        chat_response: AgentChatResponse = await self._llama_index_agent.achat(
            message=user_message, chat_history=history
        )

        return (True, chat_response.response)

    def _extract_message_and_history(
        self, messages: Optional[List[Dict]] = None, sender: Optional[Agent] = None
    ) -> Tuple[str, List[ChatMessage]]:
        """Split the conversation into the latest message and a llama-index chat history.

        Args:
            messages: the conversation; when empty or None, `self._oai_messages[sender]` is used.
            sender: key into the stored per-agent conversation history.

        Returns:
            `(message, history)`. `message` is the content of the last message ("" when it has
            no content or there are no messages). `history` holds the earlier messages as
            `ChatMessage` objects. Only "user" and "assistant" turns are kept, and a missing
            role counts as "user".
        """
        if not messages:
            messages = self._oai_messages[sender]

        if not messages:
            return "", []

        # `content` can be present but None (e.g. function/tool-call messages); never forward None.
        message = messages[-1].get("content") or ""

        history_messages: List[ChatMessage] = []
        for entry in messages[:-1]:
            role = entry.get("role", "user")
            # System, tool and function messages are not forwarded to the llama-index history.
            if role in ("user", "assistant"):
                history_messages.append(ChatMessage(content=entry.get("content", ""), role=role, additional_kwargs={}))
        return message, history_messages
|
|
@@ -1,176 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
|
|
2
|
-
#
|
|
3
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
-
#
|
|
5
|
-
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
|
6
|
-
# SPDX-License-Identifier: MIT
|
|
7
|
-
import json
|
|
8
|
-
import logging
|
|
9
|
-
from typing import List, Optional, Tuple
|
|
10
|
-
|
|
11
|
-
import replicate
|
|
12
|
-
import requests
|
|
13
|
-
|
|
14
|
-
from autogen.agentchat.agent import Agent
|
|
15
|
-
from autogen.agentchat.contrib.img_utils import get_image_data, llava_formatter
|
|
16
|
-
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
|
|
17
|
-
from autogen.code_utils import content_str
|
|
18
|
-
|
|
19
|
-
from ...formatting_utils import colored
|
|
20
|
-
|
|
21
|
-
logger = logging.getLogger(__name__)

# Module-level defaults; these are intended to be overridable.
# SEP prefixes every conversation turn in the LLaVA prompt and is the stop string sent to local workers.
SEP = "###"

# Default system message used by LLaVAAgent.
DEFAULT_LLAVA_SYS_MSG = "You are an AI agent and you can view images."
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
class LLaVAAgent(MultimodalConversableAgent):
    """Multimodal agent whose replies are generated by a LLaVA model.

    Replies come from `_image_reply`, which builds a LLaVA-style text prompt and collects
    the images referenced in the conversation. It then calls `llava_call_binary` with the
    `config_list` from `llm_config`.
    """

    def __init__(
        self,
        name: str,
        system_message: Optional[Tuple[str, List]] = DEFAULT_LLAVA_SYS_MSG,
        *args,
        **kwargs,
    ):
        """
        Args:
            name (str): agent name.
            system_message (str or list): system message for the ChatCompletion inference.
                Please override this attribute if you want to reprogram the agent.
            **kwargs (dict): Please refer to other kwargs in
                [ConversableAgent](../conversable_agent#__init__).

        Raises:
            AssertionError: if the agent ends up with no `llm_config`.
        """
        super().__init__(
            name,
            system_message=system_message,
            *args,
            **kwargs,
        )

        # Explicit check instead of `assert` so it is not stripped under `python -O`.
        if self.llm_config is None:
            raise AssertionError("llm_config must be provided.")
        self.register_reply([Agent, None], reply_func=LLaVAAgent._image_reply, position=2)

    def _image_reply(self, messages=None, sender=None, config=None):
        """Produce a reply by querying LLaVA with the conversation text and its images.

        Args:
            messages: the conversation; when None, `self._oai_messages[sender]` is used.
            sender: the agent whose conversation is being answered.
            config: unused; settings are read from `self.llm_config` instead.

        Returns:
            `(True, reply_text)`.

        Raises:
            AssertionError: if neither `messages` nor `sender` is given, or LLaVA returned an
                empty response on every retry.
        """
        if messages is None and sender is None:
            error_msg = f"Either {messages=} or {sender=} must be provided."
            logger.error(error_msg)
            raise AssertionError(error_msg)

        if messages is None:
            messages = self._oai_messages[sender]

        # The formats for LLaVA and GPT are different. So, we manually handle them here.
        images = []
        prompt = content_str(self.system_message) + "\n"
        for msg in messages:
            role = "Human" if msg["role"] == "user" else "Assistant"
            content = msg.get("content")
            # Only list-form (multimodal) content carries image parts; plain strings or None do not.
            if isinstance(content, list):
                images += [part["image_url"]["url"] for part in content if part["type"] == "image_url"]
            prompt += f"{SEP}{role}: {content_str(content)}\n"
        prompt += "\n" + SEP + "Assistant: "

        # TODO: PIL to base64
        images = [get_image_data(im) for im in images]
        print(colored(prompt, "blue"))

        out = ""
        retry = 10
        while not out and retry > 0:
            # image names will be inferred automatically from llava_call.
            # `llava_call_binary` may return None when every config fails; treat that as empty.
            out = (
                llava_call_binary(
                    prompt=prompt,
                    images=images,
                    config_list=self.llm_config["config_list"],
                    temperature=self.llm_config.get("temperature", 0.5),
                    max_new_tokens=self.llm_config.get("max_new_tokens", 2000),
                )
                or ""
            )
            retry -= 1

        if not out:
            raise AssertionError("Empty response from LLaVA.")

        return True, out
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
def _llava_call_binary_with_config(
|
|
100
|
-
prompt: str, images: list, config: dict, max_new_tokens: int = 1000, temperature: float = 0.5, seed: int = 1
|
|
101
|
-
):
|
|
102
|
-
if config["base_url"].find("0.0.0.0") >= 0 or config["base_url"].find("localhost") >= 0:
|
|
103
|
-
llava_mode = "local"
|
|
104
|
-
else:
|
|
105
|
-
llava_mode = "remote"
|
|
106
|
-
|
|
107
|
-
if llava_mode == "local":
|
|
108
|
-
headers = {"User-Agent": "LLaVA Client"}
|
|
109
|
-
pload = {
|
|
110
|
-
"model": config["model"],
|
|
111
|
-
"prompt": prompt,
|
|
112
|
-
"max_new_tokens": max_new_tokens,
|
|
113
|
-
"temperature": temperature,
|
|
114
|
-
"stop": SEP,
|
|
115
|
-
"images": images,
|
|
116
|
-
}
|
|
117
|
-
|
|
118
|
-
response = requests.post(
|
|
119
|
-
config["base_url"].rstrip("/") + "/worker_generate_stream", headers=headers, json=pload, stream=False
|
|
120
|
-
)
|
|
121
|
-
|
|
122
|
-
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
|
|
123
|
-
if chunk:
|
|
124
|
-
data = json.loads(chunk.decode("utf-8"))
|
|
125
|
-
output = data["text"].split(SEP)[-1]
|
|
126
|
-
elif llava_mode == "remote":
|
|
127
|
-
# The Replicate version of the model only support 1 image for now.
|
|
128
|
-
img = "data:image/jpeg;base64," + images[0]
|
|
129
|
-
response = replicate.run(
|
|
130
|
-
config["base_url"], input={"image": img, "prompt": prompt.replace("<image>", " "), "seed": seed}
|
|
131
|
-
)
|
|
132
|
-
# The yorickvp/llava-13b model can stream output as it's running.
|
|
133
|
-
# The predict method returns an iterator, and you can iterate over that output.
|
|
134
|
-
output = ""
|
|
135
|
-
for item in response:
|
|
136
|
-
# https://replicate.com/yorickvp/llava-13b/versions/2facb4a474a0462c15041b78b1ad70952ea46b5ec6ad29583c0b29dbd4249591/api#output-schema
|
|
137
|
-
output += item
|
|
138
|
-
|
|
139
|
-
# Remove the prompt and the space.
|
|
140
|
-
output = output.replace(prompt, "").strip().rstrip()
|
|
141
|
-
return output
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
def llava_call_binary(
    prompt: str, images: list, config_list: list, max_new_tokens: int = 1000, temperature: float = 0.5, seed: int = 1
) -> str:
    """Try each endpoint config in order and return the first successful LLaVA response.

    Args:
        prompt: the fully formatted LLaVA prompt.
        images: base64-encoded image strings.
        config_list: endpoint configs, each passed to `_llava_call_binary_with_config`.
        max_new_tokens: generation limit forwarded to each attempt.
        temperature: sampling temperature forwarded to each attempt.
        seed: random seed forwarded to each attempt.

    Returns:
        str: the response from the first config that did not raise, or "" when `config_list`
        is empty or every config failed.
    """
    # TODO 1: add caching around the LLaVA call to save compute and cost
    # TODO 2: add `seed` to ensure reproducibility. The seed is not working now.

    for config in config_list:
        try:
            return _llava_call_binary_with_config(prompt, images, config, max_new_tokens, temperature, seed)
        except Exception as e:
            # Deliberate best-effort fallback: report the failure and try the next config.
            logger.warning("Error: %s", e)
    return ""
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
def llava_call(prompt: str, llm_config: dict) -> str:
    """Send a prompt with embedded image references to LLaVA and return the response.

    The prompt is first passed through `llava_formatter` (without reordering image tokens).
    That call yields the LLaVA text prompt and the list of images it references.

    Args:
        prompt: text that may reference images, in the form `llava_formatter` understands.
        llm_config: must contain "config_list". Optional keys are "max_new_tokens"
            (default 2000), "temperature" (default 0.5) and "seed" (default None).

    Returns:
        The response returned by `llava_call_binary`.

    Raises:
        RuntimeError: if any extracted image is empty.
    """
    formatted_prompt, image_list = llava_formatter(prompt, order_image_tokens=False)

    if any(len(image) == 0 for image in image_list):
        raise RuntimeError("An image is empty!")

    generation_kwargs = {
        "config_list": llm_config["config_list"],
        "max_new_tokens": llm_config.get("max_new_tokens", 2000),
        "temperature": llm_config.get("temperature", 0.5),
        "seed": llm_config.get("seed", None),
    }
    return llava_call_binary(formatted_prompt, image_list, **generation_kwargs)
|