autobyteus 1.1.5__py3-none-any.whl → 1.1.7__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- autobyteus/agent/context/agent_config.py +6 -1
- autobyteus/agent/context/agent_runtime_state.py +7 -1
- autobyteus/agent/handlers/llm_user_message_ready_event_handler.py +30 -7
- autobyteus/agent/handlers/tool_result_event_handler.py +100 -88
- autobyteus/agent/handlers/user_input_message_event_handler.py +22 -25
- autobyteus/agent/llm_response_processor/provider_aware_tool_usage_processor.py +7 -1
- autobyteus/agent/message/__init__.py +7 -5
- autobyteus/agent/message/agent_input_user_message.py +6 -16
- autobyteus/agent/message/context_file.py +24 -24
- autobyteus/agent/message/context_file_type.py +29 -8
- autobyteus/agent/message/multimodal_message_builder.py +47 -0
- autobyteus/agent/streaming/stream_event_payloads.py +23 -4
- autobyteus/agent/system_prompt_processor/tool_manifest_injector_processor.py +6 -2
- autobyteus/agent/tool_invocation.py +27 -2
- autobyteus/agent_team/agent_team_builder.py +22 -1
- autobyteus/agent_team/bootstrap_steps/agent_configuration_preparation_step.py +9 -2
- autobyteus/agent_team/context/agent_team_config.py +1 -0
- autobyteus/agent_team/context/agent_team_runtime_state.py +0 -2
- autobyteus/llm/api/autobyteus_llm.py +33 -33
- autobyteus/llm/api/bedrock_llm.py +13 -5
- autobyteus/llm/api/claude_llm.py +13 -27
- autobyteus/llm/api/gemini_llm.py +108 -42
- autobyteus/llm/api/groq_llm.py +4 -3
- autobyteus/llm/api/mistral_llm.py +97 -51
- autobyteus/llm/api/nvidia_llm.py +6 -5
- autobyteus/llm/api/ollama_llm.py +37 -12
- autobyteus/llm/api/openai_compatible_llm.py +91 -91
- autobyteus/llm/autobyteus_provider.py +1 -1
- autobyteus/llm/base_llm.py +42 -139
- autobyteus/llm/extensions/base_extension.py +6 -6
- autobyteus/llm/extensions/token_usage_tracking_extension.py +3 -2
- autobyteus/llm/llm_factory.py +131 -61
- autobyteus/llm/ollama_provider_resolver.py +1 -0
- autobyteus/llm/providers.py +1 -0
- autobyteus/llm/token_counter/token_counter_factory.py +3 -1
- autobyteus/llm/user_message.py +43 -35
- autobyteus/llm/utils/llm_config.py +34 -18
- autobyteus/llm/utils/media_payload_formatter.py +99 -0
- autobyteus/llm/utils/messages.py +32 -25
- autobyteus/llm/utils/response_types.py +9 -3
- autobyteus/llm/utils/token_usage.py +6 -5
- autobyteus/multimedia/__init__.py +31 -0
- autobyteus/multimedia/audio/__init__.py +11 -0
- autobyteus/multimedia/audio/api/__init__.py +4 -0
- autobyteus/multimedia/audio/api/autobyteus_audio_client.py +59 -0
- autobyteus/multimedia/audio/api/gemini_audio_client.py +219 -0
- autobyteus/multimedia/audio/audio_client_factory.py +120 -0
- autobyteus/multimedia/audio/audio_model.py +97 -0
- autobyteus/multimedia/audio/autobyteus_audio_provider.py +108 -0
- autobyteus/multimedia/audio/base_audio_client.py +40 -0
- autobyteus/multimedia/image/__init__.py +11 -0
- autobyteus/multimedia/image/api/__init__.py +9 -0
- autobyteus/multimedia/image/api/autobyteus_image_client.py +97 -0
- autobyteus/multimedia/image/api/gemini_image_client.py +188 -0
- autobyteus/multimedia/image/api/openai_image_client.py +142 -0
- autobyteus/multimedia/image/autobyteus_image_provider.py +109 -0
- autobyteus/multimedia/image/base_image_client.py +67 -0
- autobyteus/multimedia/image/image_client_factory.py +118 -0
- autobyteus/multimedia/image/image_model.py +97 -0
- autobyteus/multimedia/providers.py +5 -0
- autobyteus/multimedia/runtimes.py +8 -0
- autobyteus/multimedia/utils/__init__.py +10 -0
- autobyteus/multimedia/utils/api_utils.py +19 -0
- autobyteus/multimedia/utils/multimedia_config.py +29 -0
- autobyteus/multimedia/utils/response_types.py +13 -0
- autobyteus/task_management/tools/publish_task_plan.py +4 -16
- autobyteus/task_management/tools/update_task_status.py +4 -19
- autobyteus/tools/__init__.py +5 -4
- autobyteus/tools/base_tool.py +98 -29
- autobyteus/tools/browser/standalone/__init__.py +0 -1
- autobyteus/tools/google_search.py +149 -0
- autobyteus/tools/mcp/schema_mapper.py +29 -71
- autobyteus/tools/multimedia/__init__.py +8 -0
- autobyteus/tools/multimedia/audio_tools.py +116 -0
- autobyteus/tools/multimedia/image_tools.py +186 -0
- autobyteus/tools/parameter_schema.py +82 -89
- autobyteus/tools/pydantic_schema_converter.py +81 -0
- autobyteus/tools/tool_category.py +1 -0
- autobyteus/tools/usage/formatters/default_json_example_formatter.py +89 -20
- autobyteus/tools/usage/formatters/default_xml_example_formatter.py +115 -41
- autobyteus/tools/usage/formatters/default_xml_schema_formatter.py +50 -20
- autobyteus/tools/usage/formatters/gemini_json_example_formatter.py +55 -22
- autobyteus/tools/usage/formatters/google_json_example_formatter.py +54 -21
- autobyteus/tools/usage/formatters/openai_json_example_formatter.py +53 -23
- autobyteus/tools/usage/parsers/default_xml_tool_usage_parser.py +270 -94
- autobyteus/tools/usage/parsers/provider_aware_tool_usage_parser.py +5 -2
- autobyteus/tools/usage/providers/tool_manifest_provider.py +43 -16
- autobyteus/tools/usage/registries/tool_formatting_registry.py +9 -2
- autobyteus/tools/usage/registries/tool_usage_parser_registry.py +9 -2
- autobyteus-1.1.7.dist-info/METADATA +204 -0
- {autobyteus-1.1.5.dist-info → autobyteus-1.1.7.dist-info}/RECORD +98 -71
- examples/run_browser_agent.py +1 -1
- examples/run_google_slides_agent.py +2 -2
- examples/run_mcp_google_slides_client.py +1 -1
- examples/run_sqlite_agent.py +1 -1
- autobyteus/llm/utils/image_payload_formatter.py +0 -89
- autobyteus/tools/ask_user_input.py +0 -40
- autobyteus/tools/browser/standalone/factory/google_search_factory.py +0 -25
- autobyteus/tools/browser/standalone/google_search_ui.py +0 -126
- autobyteus-1.1.5.dist-info/METADATA +0 -161
- {autobyteus-1.1.5.dist-info → autobyteus-1.1.7.dist-info}/WHEEL +0 -0
- {autobyteus-1.1.5.dist-info → autobyteus-1.1.7.dist-info}/licenses/LICENSE +0 -0
- {autobyteus-1.1.5.dist-info → autobyteus-1.1.7.dist-info}/top_level.txt +0 -0
autobyteus/llm/api/claude_llm.py
CHANGED

@@ -8,14 +8,14 @@ from autobyteus.llm.utils.llm_config import LLMConfig
 from autobyteus.llm.utils.messages import MessageRole, Message
 from autobyteus.llm.utils.token_usage import TokenUsage
 from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse
+from autobyteus.llm.user_message import LLMUserMessage
 
 logger = logging.getLogger(__name__)
 
 class ClaudeLLM(BaseLLM):
     def __init__(self, model: LLMModel = None, llm_config: LLMConfig = None):
-        # Provide defaults if not specified
         if model is None:
-            model = LLMModel
+            model = LLMModel['claude-4-sonnet']
         if llm_config is None:
             llm_config = LLMConfig()
 
@@ -37,22 +37,22 @@ class ClaudeLLM(BaseLLM):
             raise ValueError(f"Failed to initialize Anthropic client: {str(e)}")
 
     def _get_non_system_messages(self) -> List[Dict]:
-        """
-        Returns all messages excluding system messages for Anthropic API compatibility.
-        """
+        # NOTE: This will need to be updated to handle multimodal messages for Claude
         return [msg.to_dict() for msg in self.messages if msg.role != MessageRole.SYSTEM]
 
     def _create_token_usage(self, input_tokens: int, output_tokens: int) -> TokenUsage:
-        """Convert Anthropic usage data to TokenUsage format."""
         return TokenUsage(
             prompt_tokens=input_tokens,
             completion_tokens=output_tokens,
             total_tokens=input_tokens + output_tokens
         )
 
-    async def _send_user_message_to_llm(self, user_message:
+    async def _send_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> CompleteResponse:
         self.add_user_message(user_message)
 
+        # NOTE: This implementation does not yet support multimodal inputs for Claude.
+        # It will only send the text content.
+
         try:
             response = self.client.messages.create(
                 model=self.model.value,
@@ -81,12 +81,15 @@ class ClaudeLLM(BaseLLM):
             raise ValueError(f"Error in Claude API call: {str(e)}")
 
     async def _stream_user_message_to_llm(
-        self, user_message:
+        self, user_message: LLMUserMessage, **kwargs
     ) -> AsyncGenerator[ChunkResponse, None]:
         self.add_user_message(user_message)
         complete_response = ""
         final_message = None
 
+        # NOTE: This implementation does not yet support multimodal inputs for Claude.
+        # It will only send the text content.
+
         try:
             with self.client.messages.stream(
                 model=self.model.value,
@@ -96,30 +99,13 @@ class ClaudeLLM(BaseLLM):
                 messages=self._get_non_system_messages(),
             ) as stream:
                 for event in stream:
-
-
-                    if event.type == "message_start":
-                        logger.debug(f"Message Start: {event.message}")
-
-                    elif event.type == "content_block_start":
-                        logger.debug(f"Content Block Start at index {event.index}: {event.content_block}")
-
-                    elif event.type == "content_block_delta" and event.delta.type == "text_delta":
-                        logger.debug(f"Text Delta: {event.delta.text}")
+                    if event.type == "content_block_delta" and event.delta.type == "text_delta":
                         complete_response += event.delta.text
                         yield ChunkResponse(
                             content=event.delta.text,
                             is_complete=False
                         )
 
-                    elif event.type == "message_delta":
-                        logger.debug(f"Message Delta: Stop Reason - {event.delta.stop_reason}, "
-                                     f"Stop Sequence - {event.delta.stop_sequence}")
-
-                    elif event.type == "content_block_stop":
-                        logger.debug(f"Content Block Stop at index {event.index}: {event.content_block}")
-
-                # Get final message for token usage
                 final_message = stream.get_final_message()
                 if final_message:
                     token_usage = self._create_token_usage(
@@ -140,4 +126,4 @@ class ClaudeLLM(BaseLLM):
             raise ValueError(f"Error in Claude API streaming: {str(e)}")
 
     async def cleanup(self):
-        super().cleanup()
+        await super().cleanup()
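For orientation, the streaming rewrite above keeps only the text-delta events and reads token usage from the final message. Below is a minimal, standalone sketch of that pattern against the Anthropic Python SDK; the model id and prompt are illustrative placeholders rather than values taken from this package, and ANTHROPIC_API_KEY is assumed to be set.

import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

complete_response = ""
with client.messages.stream(
    model="claude-sonnet-4-20250514",  # illustrative model id
    max_tokens=1024,
    messages=[{"role": "user", "content": "Hello"}],
) as stream:
    for event in stream:
        # As in the diff, only text deltas are consumed; other event types are ignored.
        if event.type == "content_block_delta" and event.delta.type == "text_delta":
            complete_response += event.delta.text
    # Token usage comes from the final message, mirroring _create_token_usage above.
    final_message = stream.get_final_message()

print(complete_response)
print(final_message.usage.input_tokens, final_message.usage.output_tokens)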
autobyteus/llm/api/gemini_llm.py
CHANGED

@@ -1,6 +1,6 @@
 import logging
-from typing import Dict,
-import google.generativeai as genai
+from typing import Dict, List, AsyncGenerator, Any
+import google.generativeai as genai  # CHANGED: Using the older 'google.generativeai' library
 import os
 from autobyteus.llm.models import LLMModel
 from autobyteus.llm.base_llm import BaseLLM
@@ -8,68 +8,93 @@ from autobyteus.llm.utils.llm_config import LLMConfig
 from autobyteus.llm.utils.messages import MessageRole, Message
 from autobyteus.llm.utils.token_usage import TokenUsage
 from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse
+from autobyteus.llm.user_message import LLMUserMessage
 
 logger = logging.getLogger(__name__)
 
+def _format_gemini_history(messages: List[Message]) -> List[Dict[str, Any]]:
+    """
+    Formats internal message history for the Gemini API.
+    This function remains compatible with the older library.
+    """
+    history = []
+    # System message is handled separately in the model initialization
+    for msg in messages:
+        if msg.role in [MessageRole.USER, MessageRole.ASSISTANT]:
+            role = 'model' if msg.role == MessageRole.ASSISTANT else 'user'
+            history.append({"role": role, "parts": [{"text": msg.content}]})
+    return history
+
 class GeminiLLM(BaseLLM):
     def __init__(self, model: LLMModel = None, llm_config: LLMConfig = None):
-        self.generation_config = {
-            "temperature": 0,
-            "top_p": 0.95,
-            "top_k": 64,
-            "max_output_tokens": 8192,
-            "response_mime_type": "text/plain",
-        }
-
-        # Provide defaults if not specified
         if model is None:
-            model = LLMModel.
+            model = LLMModel['gemini-2.5-flash']  # Note: Ensure model name is compatible, e.g., 'gemini-1.5-flash-latest'
         if llm_config is None:
            llm_config = LLMConfig()
-
+
         super().__init__(model=model, llm_config=llm_config)
-
-
+
+        # CHANGED: Initialization flow. Configure API key and then instantiate the model.
+        self.initialize()
+
+        system_instruction = self.system_message if self.system_message else None
+
+        self.model = genai.GenerativeModel(
+            model_name=self.model.value,
+            system_instruction=system_instruction
+        )
 
-    @
-    def initialize(
+    @staticmethod
+    def initialize():
+        """
+        CHANGED: This method now configures the genai library with the API key
+        instead of creating a client instance.
+        """
         api_key = os.environ.get("GEMINI_API_KEY")
         if not api_key:
             logger.error("GEMINI_API_KEY environment variable is not set.")
-            raise ValueError(
-                "GEMINI_API_KEY environment variable is not set. "
-                "Please set this variable in your environment."
-            )
+            raise ValueError("GEMINI_API_KEY environment variable is not set.")
         try:
             genai.configure(api_key=api_key)
-            return genai
         except Exception as e:
-            logger.error(f"Failed to
-            raise ValueError(f"Failed to
+            logger.error(f"Failed to configure Gemini client: {str(e)}")
+            raise ValueError(f"Failed to configure Gemini client: {str(e)}")
 
-    def
-
-
-
-
-
-
-
-
-
+    def _get_generation_config(self) -> Dict[str, Any]:
+        """
+        CHANGED: Builds the generation config as a dictionary.
+        'thinking_config' is not available in the old library.
+        'system_instruction' is passed during model initialization.
+        """
+        # Basic configuration, you can expand this with temperature, top_p, etc.
+        # from self.llm_config if needed.
+        config = {
+            "response_mime_type": "text/plain",
+            # Example: "temperature": self.llm_config.temperature
+        }
+        return config
 
-    async def _send_user_message_to_llm(self, user_message:
+    async def _send_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> CompleteResponse:
         self.add_user_message(user_message)
+
         try:
-            self.
-
+            history = _format_gemini_history(self.messages)
+            generation_config = self._get_generation_config()
+
+            # CHANGED: API call now uses the model instance directly.
+            response = await self.model.generate_content_async(
+                contents=history,
+                generation_config=generation_config,
+            )
+
             assistant_message = response.text
             self.add_assistant_message(assistant_message)
 
+            # CHANGED: Token usage is extracted from 'usage_metadata'.
             token_usage = TokenUsage(
-                prompt_tokens=
-                completion_tokens=
-                total_tokens=
+                prompt_tokens=response.usage_metadata.prompt_token_count,
+                completion_tokens=response.usage_metadata.candidates_token_count,
+                total_tokens=response.usage_metadata.total_token_count
             )
 
             return CompleteResponse(
@@ -80,6 +105,47 @@ class GeminiLLM(BaseLLM):
             logger.error(f"Error in Gemini API call: {str(e)}")
             raise ValueError(f"Error in Gemini API call: {str(e)}")
 
+    async def _stream_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> AsyncGenerator[ChunkResponse, None]:
+        self.add_user_message(user_message)
+        complete_response = ""
+
+        try:
+            history = _format_gemini_history(self.messages)
+            generation_config = self._get_generation_config()
+
+            # CHANGED: API call for streaming is now part of generate_content_async.
+            response_stream = await self.model.generate_content_async(
+                contents=history,
+                generation_config=generation_config,
+                stream=True
+            )
+
+            async for chunk in response_stream:
+                chunk_text = chunk.text
+                complete_response += chunk_text
+                yield ChunkResponse(
+                    content=chunk_text,
+                    is_complete=False
+                )
+
+            self.add_assistant_message(complete_response)
+
+            # NOTE: The old library's async stream does not easily expose token usage.
+            # Keeping it at 0, consistent with your original implementation.
+            token_usage = TokenUsage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0
+            )
+
+            yield ChunkResponse(
+                content="",
+                is_complete=True,
+                usage=token_usage
+            )
+        except Exception as e:
+            logger.error(f"Error in Gemini API streaming call: {str(e)}")
+            raise ValueError(f"Error in Gemini API streaming call: {str(e)}")
+
     async def cleanup(self):
-
-        super().cleanup()
+        await super().cleanup()
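For reference, here is a minimal sketch of the google.generativeai flow the rewritten GeminiLLM follows: configure the API key once, build a GenerativeModel with an optional system instruction, then call generate_content_async for both complete and streamed responses. The model name, system instruction, and prompt are illustrative assumptions, and GEMINI_API_KEY is assumed to be set.

import asyncio
import os
import google.generativeai as genai

async def main() -> None:
    genai.configure(api_key=os.environ["GEMINI_API_KEY"])
    model = genai.GenerativeModel(
        model_name="gemini-1.5-flash",                    # illustrative model name
        system_instruction="You are a helpful assistant.",
    )
    history = [{"role": "user", "parts": [{"text": "Say hello."}]}]

    # Complete response: token usage is read from response.usage_metadata.
    response = await model.generate_content_async(contents=history)
    print(response.text, response.usage_metadata.total_token_count)

    # Streamed response: pass stream=True and iterate the chunks asynchronously.
    stream = await model.generate_content_async(contents=history, stream=True)
    async for chunk in stream:
        print(chunk.text, end="")

asyncio.run(main())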
autobyteus/llm/api/groq_llm.py
CHANGED

@@ -7,6 +7,7 @@ from autobyteus.llm.utils.llm_config import LLMConfig
 from autobyteus.llm.utils.messages import MessageRole, Message
 from autobyteus.llm.utils.token_usage import TokenUsage
 from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse
+from autobyteus.llm.user_message import LLMUserMessage
 
 logger = logging.getLogger(__name__)
 
@@ -36,7 +37,7 @@ class GroqLLM(BaseLLM):
         except Exception as e:
             raise ValueError(f"Failed to initialize Groq client: {str(e)}")
 
-    async def _send_user_message_to_llm(self, user_message:
+    async def _send_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> CompleteResponse:
         self.add_user_message(user_message)
         try:
             # Placeholder for sending message to Groq API
@@ -58,7 +59,7 @@ class GroqLLM(BaseLLM):
             raise ValueError(f"Error in Groq API call: {str(e)}")
 
     async def _stream_user_message_to_llm(
-        self, user_message:
+        self, user_message: LLMUserMessage, **kwargs
     ) -> AsyncGenerator[ChunkResponse, None]:
         self.add_user_message(user_message)
         complete_response = ""
@@ -90,4 +91,4 @@ class GroqLLM(BaseLLM):
             raise ValueError(f"Error in Groq API streaming: {str(e)}")
 
     async def cleanup(self):
-        super().cleanup()
+        await super().cleanup()
autobyteus/llm/api/mistral_llm.py
CHANGED

@@ -1,45 +1,91 @@
-from typing import Dict, Optional, List, AsyncGenerator
+from typing import Dict, Optional, List, Any, AsyncGenerator, Union
 import os
 import logging
+import httpx
+import asyncio
 from autobyteus.llm.models import LLMModel
 from autobyteus.llm.base_llm import BaseLLM
 from mistralai import Mistral
-from autobyteus.llm.utils.messages import
+from autobyteus.llm.utils.messages import Message, MessageRole
 from autobyteus.llm.utils.llm_config import LLMConfig
 from autobyteus.llm.utils.token_usage import TokenUsage
 from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse
+from autobyteus.llm.user_message import LLMUserMessage
+from autobyteus.llm.utils.media_payload_formatter import image_source_to_base64, get_mime_type, is_valid_image_path
 
-# Configure logger
 logger = logging.getLogger(__name__)
 
+async def _format_mistral_messages(messages: List[Message]) -> List[Dict[str, Any]]:
+    """Formats a list of internal Message objects into a list of dictionaries for the Mistral API."""
+    mistral_messages = []
+    for msg in messages:
+        # Skip empty messages from non-system roles as Mistral API may reject them
+        if not msg.content and not msg.image_urls and msg.role != MessageRole.SYSTEM:
+            continue
+
+        content: Union[str, List[Dict[str, Any]]]
+
+        if msg.image_urls:
+            content_parts: List[Dict[str, Any]] = []
+            if msg.content:
+                content_parts.append({"type": "text", "text": msg.content})
+
+            image_tasks = [image_source_to_base64(url) for url in msg.image_urls]
+            try:
+                base64_images = await asyncio.gather(*image_tasks)
+                for i, b64_image in enumerate(base64_images):
+                    original_url = msg.image_urls[i]
+                    mime_type = get_mime_type(original_url) if is_valid_image_path(original_url) else "image/jpeg"
+                    data_uri = f"data:{mime_type};base64,{b64_image}"
+
+                    # Mistral's format for image parts
+                    content_parts.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": data_uri
+                        }
+                    })
+            except Exception as e:
+                logger.error(f"Error processing images for Mistral: {e}")
+
+            if msg.audio_urls:
+                logger.warning("MistralLLM does not yet support audio; skipping.")
+            if msg.video_urls:
+                logger.warning("MistralLLM does not yet support video; skipping.")
+
+            content = content_parts
+        else:
+            content = msg.content or ""
+
+        mistral_messages.append({"role": msg.role.value, "content": content})
+
+    return mistral_messages
+
+
 class MistralLLM(BaseLLM):
     def __init__(self, model: LLMModel = None, llm_config: LLMConfig = None):
-        # Provide defaults if not specified
         if model is None:
-            model = LLMModel
+            model = LLMModel['mistral-large']
         if llm_config is None:
             llm_config = LLMConfig()
 
         super().__init__(model=model, llm_config=llm_config)
-        self.
+        self.http_client = httpx.AsyncClient()
+        self.client: Mistral = self._initialize()
         logger.info(f"MistralLLM initialized with model: {self.model}")
 
-
-    def initialize(cls):
+    def _initialize(self) -> Mistral:
         mistral_api_key = os.environ.get("MISTRAL_API_KEY")
         if not mistral_api_key:
             logger.error("MISTRAL_API_KEY environment variable is not set")
-            raise ValueError(
-                "MISTRAL_API_KEY environment variable is not set. "
-                "Please set this variable in your environment."
-            )
+            raise ValueError("MISTRAL_API_KEY environment variable is not set.")
         try:
-            return Mistral(api_key=mistral_api_key)
+            return Mistral(api_key=mistral_api_key, client=self.http_client)
         except Exception as e:
             logger.error(f"Failed to initialize Mistral client: {str(e)}")
             raise ValueError(f"Failed to initialize Mistral client: {str(e)}")
 
-    def _create_token_usage(self, usage_data:
+    def _create_token_usage(self, usage_data: Any) -> TokenUsage:
         """Convert Mistral usage data to TokenUsage format."""
         return TokenUsage(
             prompt_tokens=usage_data.prompt_tokens,
@@ -48,26 +94,26 @@ class MistralLLM(BaseLLM):
         )
 
     async def _send_user_message_to_llm(
-        self, user_message:
+        self, user_message: LLMUserMessage, **kwargs
     ) -> CompleteResponse:
         self.add_user_message(user_message)
-
+
         try:
-            mistral_messages =
+            mistral_messages = await _format_mistral_messages(self.messages)
 
-            chat_response = self.client.chat.
+            chat_response = await self.client.chat.complete_async(
                 model=self.model.value,
                 messages=mistral_messages,
+                temperature=self.config.temperature,
+                max_tokens=self.config.max_tokens,
+                top_p=self.config.top_p,
             )
 
-            assistant_message = chat_response.choices.message.content
+            assistant_message = chat_response.choices[0].message.content
             self.add_assistant_message(assistant_message)
 
-
-
-            if hasattr(chat_response, 'usage') and chat_response.usage:
-                token_usage = self._create_token_usage(chat_response.usage)
-                logger.debug(f"Token usage recorded: {token_usage}")
+            token_usage = self._create_token_usage(chat_response.usage)
+            logger.debug(f"Token usage recorded: {token_usage}")
 
             return CompleteResponse(
                 content=assistant_message,
@@ -78,48 +124,48 @@ class MistralLLM(BaseLLM):
             raise ValueError(f"Error in Mistral API call: {str(e)}")
 
     async def _stream_user_message_to_llm(
-        self, user_message:
+        self, user_message: LLMUserMessage, **kwargs
     ) -> AsyncGenerator[ChunkResponse, None]:
         self.add_user_message(user_message)
 
+        accumulated_message = ""
+        final_usage = None
+
         try:
-            mistral_messages =
-
-            stream =
+            mistral_messages = await _format_mistral_messages(self.messages)
+
+            stream = self.client.chat.stream_async(
                 model=self.model.value,
                 messages=mistral_messages,
+                temperature=self.config.temperature,
+                max_tokens=self.config.max_tokens,
+                top_p=self.config.top_p,
             )
 
-            accumulated_message = ""
-
             async for chunk in stream:
-                if chunk.
-                    token = chunk.
+                if chunk.choices and chunk.choices[0].delta.content is not None:
+                    token = chunk.choices[0].delta.content
                     accumulated_message += token
 
-
-
-
-
-
-
-
-
-
-
-
-
-                        usage=token_usage
-                    )
-
-            # After streaming is complete, store the full message
+                    yield ChunkResponse(content=token, is_complete=False)
+
+                if hasattr(chunk, 'usage') and chunk.usage:
+                    final_usage = self._create_token_usage(chunk.usage)
+
+            # Yield the final chunk with usage data
+            yield ChunkResponse(
+                content="",
+                is_complete=True,
+                usage=final_usage
+            )
+
             self.add_assistant_message(accumulated_message)
         except Exception as e:
             logger.error(f"Error in Mistral API streaming call: {str(e)}")
             raise ValueError(f"Error in Mistral API streaming call: {str(e)}")
 
     async def cleanup(self):
-        # Clean up any resources if needed
         logger.debug("Cleaning up MistralLLM instance")
-        self.
-
+        if self.http_client and not self.http_client.is_closed:
+            await self.http_client.aclose()
+        await super().cleanup()
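As a usage illustration, the sketch below builds one multimodal user message in the same shape _format_mistral_messages produces above (a text part plus an image_url part carrying a base64 data URI) and sends it through the async chat call. The model name and the truncated data URI are placeholders, not values taken from this package, and MISTRAL_API_KEY is assumed to be set.

import asyncio
import os
from mistralai import Mistral

async def main() -> None:
    client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

    # One user message with a text part and an image part, matching the structure
    # built by _format_mistral_messages in the diff above.
    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,...."}},  # placeholder data URI
        ],
    }]

    response = await client.chat.complete_async(
        model="pixtral-large-latest",  # illustrative multimodal model name
        messages=messages,
    )
    print(response.choices[0].message.content)

asyncio.run(main())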
autobyteus/llm/api/nvidia_llm.py
CHANGED

@@ -8,6 +8,7 @@ from autobyteus.llm.utils.llm_config import LLMConfig
 from autobyteus.llm.utils.messages import MessageRole, Message
 from autobyteus.llm.utils.token_usage import TokenUsage
 from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse
+from autobyteus.llm.user_message import LLMUserMessage
 
 logger = logging.getLogger(__name__)
 
@@ -38,11 +39,11 @@ class NvidiaLLM(BaseLLM):
         except Exception as e:
             raise ValueError(f"Failed to initialize Nvidia client: {str(e)}")
 
-    async def _send_user_message_to_llm(self, user_message:
+    async def _send_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> CompleteResponse:
         self.add_user_message(user_message)
         try:
             completion = self.client.chat.completions.create(
-                model=self.model,
+                model=self.model.value,
                 messages=[msg.to_dict() for msg in self.messages],
                 temperature=0,
                 top_p=1,
@@ -65,12 +66,12 @@ class NvidiaLLM(BaseLLM):
         except Exception as e:
             raise ValueError(f"Error in Nvidia API call: {str(e)}")
 
-    async def
+    async def _stream_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> AsyncGenerator[ChunkResponse, None]:
         self.add_user_message(user_message)
         complete_response = ""
         try:
             completion = self.client.chat.completions.create(
-                model=self.model,
+                model=self.model.value,
                 messages=[msg.to_dict() for msg in self.messages],
                 temperature=0,
                 top_p=1,
@@ -104,4 +105,4 @@ class NvidiaLLM(BaseLLM):
             raise ValueError(f"Error in Nvidia API streaming call: {str(e)}")
 
     async def cleanup(self):
-        super().cleanup()
+        await super().cleanup()