autobyteus-1.1.5-py3-none-any.whl → autobyteus-1.1.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. autobyteus/agent/context/agent_config.py +6 -1
  2. autobyteus/agent/context/agent_runtime_state.py +7 -1
  3. autobyteus/agent/handlers/llm_user_message_ready_event_handler.py +30 -7
  4. autobyteus/agent/handlers/tool_result_event_handler.py +100 -88
  5. autobyteus/agent/handlers/user_input_message_event_handler.py +22 -25
  6. autobyteus/agent/llm_response_processor/provider_aware_tool_usage_processor.py +7 -1
  7. autobyteus/agent/message/__init__.py +7 -5
  8. autobyteus/agent/message/agent_input_user_message.py +6 -16
  9. autobyteus/agent/message/context_file.py +24 -24
  10. autobyteus/agent/message/context_file_type.py +29 -8
  11. autobyteus/agent/message/multimodal_message_builder.py +47 -0
  12. autobyteus/agent/streaming/stream_event_payloads.py +23 -4
  13. autobyteus/agent/system_prompt_processor/tool_manifest_injector_processor.py +6 -2
  14. autobyteus/agent/tool_invocation.py +27 -2
  15. autobyteus/agent_team/agent_team_builder.py +22 -1
  16. autobyteus/agent_team/bootstrap_steps/agent_configuration_preparation_step.py +9 -2
  17. autobyteus/agent_team/context/agent_team_config.py +1 -0
  18. autobyteus/agent_team/context/agent_team_runtime_state.py +0 -2
  19. autobyteus/llm/api/autobyteus_llm.py +33 -33
  20. autobyteus/llm/api/bedrock_llm.py +13 -5
  21. autobyteus/llm/api/claude_llm.py +13 -27
  22. autobyteus/llm/api/gemini_llm.py +108 -42
  23. autobyteus/llm/api/groq_llm.py +4 -3
  24. autobyteus/llm/api/mistral_llm.py +97 -51
  25. autobyteus/llm/api/nvidia_llm.py +6 -5
  26. autobyteus/llm/api/ollama_llm.py +37 -12
  27. autobyteus/llm/api/openai_compatible_llm.py +91 -91
  28. autobyteus/llm/autobyteus_provider.py +1 -1
  29. autobyteus/llm/base_llm.py +42 -139
  30. autobyteus/llm/extensions/base_extension.py +6 -6
  31. autobyteus/llm/extensions/token_usage_tracking_extension.py +3 -2
  32. autobyteus/llm/llm_factory.py +131 -61
  33. autobyteus/llm/ollama_provider_resolver.py +1 -0
  34. autobyteus/llm/providers.py +1 -0
  35. autobyteus/llm/token_counter/token_counter_factory.py +3 -1
  36. autobyteus/llm/user_message.py +43 -35
  37. autobyteus/llm/utils/llm_config.py +34 -18
  38. autobyteus/llm/utils/media_payload_formatter.py +99 -0
  39. autobyteus/llm/utils/messages.py +32 -25
  40. autobyteus/llm/utils/response_types.py +9 -3
  41. autobyteus/llm/utils/token_usage.py +6 -5
  42. autobyteus/multimedia/__init__.py +31 -0
  43. autobyteus/multimedia/audio/__init__.py +11 -0
  44. autobyteus/multimedia/audio/api/__init__.py +4 -0
  45. autobyteus/multimedia/audio/api/autobyteus_audio_client.py +59 -0
  46. autobyteus/multimedia/audio/api/gemini_audio_client.py +219 -0
  47. autobyteus/multimedia/audio/audio_client_factory.py +120 -0
  48. autobyteus/multimedia/audio/audio_model.py +97 -0
  49. autobyteus/multimedia/audio/autobyteus_audio_provider.py +108 -0
  50. autobyteus/multimedia/audio/base_audio_client.py +40 -0
  51. autobyteus/multimedia/image/__init__.py +11 -0
  52. autobyteus/multimedia/image/api/__init__.py +9 -0
  53. autobyteus/multimedia/image/api/autobyteus_image_client.py +97 -0
  54. autobyteus/multimedia/image/api/gemini_image_client.py +188 -0
  55. autobyteus/multimedia/image/api/openai_image_client.py +142 -0
  56. autobyteus/multimedia/image/autobyteus_image_provider.py +109 -0
  57. autobyteus/multimedia/image/base_image_client.py +67 -0
  58. autobyteus/multimedia/image/image_client_factory.py +118 -0
  59. autobyteus/multimedia/image/image_model.py +97 -0
  60. autobyteus/multimedia/providers.py +5 -0
  61. autobyteus/multimedia/runtimes.py +8 -0
  62. autobyteus/multimedia/utils/__init__.py +10 -0
  63. autobyteus/multimedia/utils/api_utils.py +19 -0
  64. autobyteus/multimedia/utils/multimedia_config.py +29 -0
  65. autobyteus/multimedia/utils/response_types.py +13 -0
  66. autobyteus/task_management/tools/publish_task_plan.py +4 -16
  67. autobyteus/task_management/tools/update_task_status.py +4 -19
  68. autobyteus/tools/__init__.py +5 -4
  69. autobyteus/tools/base_tool.py +98 -29
  70. autobyteus/tools/browser/standalone/__init__.py +0 -1
  71. autobyteus/tools/google_search.py +149 -0
  72. autobyteus/tools/mcp/schema_mapper.py +29 -71
  73. autobyteus/tools/multimedia/__init__.py +8 -0
  74. autobyteus/tools/multimedia/audio_tools.py +116 -0
  75. autobyteus/tools/multimedia/image_tools.py +186 -0
  76. autobyteus/tools/parameter_schema.py +82 -89
  77. autobyteus/tools/pydantic_schema_converter.py +81 -0
  78. autobyteus/tools/tool_category.py +1 -0
  79. autobyteus/tools/usage/formatters/default_json_example_formatter.py +89 -20
  80. autobyteus/tools/usage/formatters/default_xml_example_formatter.py +115 -41
  81. autobyteus/tools/usage/formatters/default_xml_schema_formatter.py +50 -20
  82. autobyteus/tools/usage/formatters/gemini_json_example_formatter.py +55 -22
  83. autobyteus/tools/usage/formatters/google_json_example_formatter.py +54 -21
  84. autobyteus/tools/usage/formatters/openai_json_example_formatter.py +53 -23
  85. autobyteus/tools/usage/parsers/default_xml_tool_usage_parser.py +270 -94
  86. autobyteus/tools/usage/parsers/provider_aware_tool_usage_parser.py +5 -2
  87. autobyteus/tools/usage/providers/tool_manifest_provider.py +43 -16
  88. autobyteus/tools/usage/registries/tool_formatting_registry.py +9 -2
  89. autobyteus/tools/usage/registries/tool_usage_parser_registry.py +9 -2
  90. autobyteus-1.1.7.dist-info/METADATA +204 -0
  91. {autobyteus-1.1.5.dist-info → autobyteus-1.1.7.dist-info}/RECORD +98 -71
  92. examples/run_browser_agent.py +1 -1
  93. examples/run_google_slides_agent.py +2 -2
  94. examples/run_mcp_google_slides_client.py +1 -1
  95. examples/run_sqlite_agent.py +1 -1
  96. autobyteus/llm/utils/image_payload_formatter.py +0 -89
  97. autobyteus/tools/ask_user_input.py +0 -40
  98. autobyteus/tools/browser/standalone/factory/google_search_factory.py +0 -25
  99. autobyteus/tools/browser/standalone/google_search_ui.py +0 -126
  100. autobyteus-1.1.5.dist-info/METADATA +0 -161
  101. {autobyteus-1.1.5.dist-info → autobyteus-1.1.7.dist-info}/WHEEL +0 -0
  102. {autobyteus-1.1.5.dist-info → autobyteus-1.1.7.dist-info}/licenses/LICENSE +0 -0
  103. {autobyteus-1.1.5.dist-info → autobyteus-1.1.7.dist-info}/top_level.txt +0 -0

autobyteus/llm/api/claude_llm.py

@@ -8,14 +8,14 @@ from autobyteus.llm.utils.llm_config import LLMConfig
 from autobyteus.llm.utils.messages import MessageRole, Message
 from autobyteus.llm.utils.token_usage import TokenUsage
 from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse
+from autobyteus.llm.user_message import LLMUserMessage
 
 logger = logging.getLogger(__name__)
 
 class ClaudeLLM(BaseLLM):
     def __init__(self, model: LLMModel = None, llm_config: LLMConfig = None):
-        # Provide defaults if not specified
         if model is None:
-            model = LLMModel.CLAUDE_3_5_SONNET_API
+            model = LLMModel['claude-4-sonnet']
         if llm_config is None:
             llm_config = LLMConfig()
 
@@ -37,22 +37,22 @@ class ClaudeLLM(BaseLLM):
             raise ValueError(f"Failed to initialize Anthropic client: {str(e)}")
 
     def _get_non_system_messages(self) -> List[Dict]:
-        """
-        Returns all messages excluding system messages for Anthropic API compatibility.
-        """
+        # NOTE: This will need to be updated to handle multimodal messages for Claude
         return [msg.to_dict() for msg in self.messages if msg.role != MessageRole.SYSTEM]
 
     def _create_token_usage(self, input_tokens: int, output_tokens: int) -> TokenUsage:
-        """Convert Anthropic usage data to TokenUsage format."""
         return TokenUsage(
             prompt_tokens=input_tokens,
             completion_tokens=output_tokens,
             total_tokens=input_tokens + output_tokens
         )
 
-    async def _send_user_message_to_llm(self, user_message: str, image_urls: Optional[List[str]] = None, **kwargs) -> CompleteResponse:
+    async def _send_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> CompleteResponse:
         self.add_user_message(user_message)
 
+        # NOTE: This implementation does not yet support multimodal inputs for Claude.
+        # It will only send the text content.
+
         try:
             response = self.client.messages.create(
                 model=self.model.value,
@@ -81,12 +81,15 @@ class ClaudeLLM(BaseLLM):
             raise ValueError(f"Error in Claude API call: {str(e)}")
 
     async def _stream_user_message_to_llm(
-        self, user_message: str, image_urls: Optional[List[str]] = None, **kwargs
+        self, user_message: LLMUserMessage, **kwargs
     ) -> AsyncGenerator[ChunkResponse, None]:
         self.add_user_message(user_message)
         complete_response = ""
         final_message = None
 
+        # NOTE: This implementation does not yet support multimodal inputs for Claude.
+        # It will only send the text content.
+
         try:
             with self.client.messages.stream(
                 model=self.model.value,
@@ -96,30 +99,13 @@
                 messages=self._get_non_system_messages(),
             ) as stream:
                 for event in stream:
-                    logger.debug(f"Event Received: {event.type}")
-
-                    if event.type == "message_start":
-                        logger.debug(f"Message Start: {event.message}")
-
-                    elif event.type == "content_block_start":
-                        logger.debug(f"Content Block Start at index {event.index}: {event.content_block}")
-
-                    elif event.type == "content_block_delta" and event.delta.type == "text_delta":
-                        logger.debug(f"Text Delta: {event.delta.text}")
+                    if event.type == "content_block_delta" and event.delta.type == "text_delta":
                         complete_response += event.delta.text
                         yield ChunkResponse(
                             content=event.delta.text,
                             is_complete=False
                         )
 
-                    elif event.type == "message_delta":
-                        logger.debug(f"Message Delta: Stop Reason - {event.delta.stop_reason}, "
-                                     f"Stop Sequence - {event.delta.stop_sequence}")
-
-                    elif event.type == "content_block_stop":
-                        logger.debug(f"Content Block Stop at index {event.index}: {event.content_block}")
-
-                # Get final message for token usage
                 final_message = stream.get_final_message()
                 if final_message:
                     token_usage = self._create_token_usage(
@@ -140,4 +126,4 @@
             raise ValueError(f"Error in Claude API streaming: {str(e)}")
 
     async def cleanup(self):
-        super().cleanup()
+        await super().cleanup()
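
The signature change above, from user_message: str plus image_urls to a single LLMUserMessage, repeats across every provider diff in this release. A minimal sketch of the 1.1.7 calling convention follows; the LLMUserMessage constructor field (content) is an assumption inferred from this diff rather than a documented API, and a real caller would normally go through BaseLLM's public entry points rather than the underscore-prefixed hooks shown here.

# Hypothetical usage sketch of the 1.1.7 calling convention; the LLMUserMessage
# constructor arguments are assumptions inferred from this diff, not a confirmed API.
import asyncio

from autobyteus.llm.user_message import LLMUserMessage
from autobyteus.llm.api.claude_llm import ClaudeLLM

async def main() -> None:
    llm = ClaudeLLM()  # defaults to LLMModel['claude-4-sonnet'] per the diff above

    # 1.1.5 style (removed): _send_user_message_to_llm("text", image_urls=[...])
    # 1.1.7 style: wrap the text (and, for providers that support it, media) in one object.
    message = LLMUserMessage(content="Summarize this release in one sentence.")
    response = await llm._send_user_message_to_llm(message)
    print(response.content)

    await llm.cleanup()  # cleanup() is awaited in 1.1.7

if __name__ == "__main__":
    asyncio.run(main())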

autobyteus/llm/api/gemini_llm.py

@@ -1,6 +1,6 @@
 import logging
-from typing import Dict, Optional, List, AsyncGenerator
-import google.generativeai as genai
+from typing import Dict, List, AsyncGenerator, Any
+import google.generativeai as genai  # CHANGED: Using the older 'google.generativeai' library
 import os
 from autobyteus.llm.models import LLMModel
 from autobyteus.llm.base_llm import BaseLLM
@@ -8,68 +8,93 @@ from autobyteus.llm.utils.llm_config import LLMConfig
 from autobyteus.llm.utils.messages import MessageRole, Message
 from autobyteus.llm.utils.token_usage import TokenUsage
 from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse
+from autobyteus.llm.user_message import LLMUserMessage
 
 logger = logging.getLogger(__name__)
 
+def _format_gemini_history(messages: List[Message]) -> List[Dict[str, Any]]:
+    """
+    Formats internal message history for the Gemini API.
+    This function remains compatible with the older library.
+    """
+    history = []
+    # System message is handled separately in the model initialization
+    for msg in messages:
+        if msg.role in [MessageRole.USER, MessageRole.ASSISTANT]:
+            role = 'model' if msg.role == MessageRole.ASSISTANT else 'user'
+            history.append({"role": role, "parts": [{"text": msg.content}]})
+    return history
+
 class GeminiLLM(BaseLLM):
     def __init__(self, model: LLMModel = None, llm_config: LLMConfig = None):
-        self.generation_config = {
-            "temperature": 0,
-            "top_p": 0.95,
-            "top_k": 64,
-            "max_output_tokens": 8192,
-            "response_mime_type": "text/plain",
-        }
-
-        # Provide defaults if not specified
         if model is None:
-            model = LLMModel.GEMINI_1_5_FLASH_API
+            model = LLMModel['gemini-2.5-flash']  # Note: Ensure model name is compatible, e.g., 'gemini-1.5-flash-latest'
         if llm_config is None:
             llm_config = LLMConfig()
-
+
         super().__init__(model=model, llm_config=llm_config)
-        self.client = self.initialize()
-        self.chat_session = None
+
+        # CHANGED: Initialization flow. Configure API key and then instantiate the model.
+        self.initialize()
+
+        system_instruction = self.system_message if self.system_message else None
+
+        self.model = genai.GenerativeModel(
+            model_name=self.model.value,
+            system_instruction=system_instruction
+        )
 
-    @classmethod
-    def initialize(cls):
+    @staticmethod
+    def initialize():
+        """
+        CHANGED: This method now configures the genai library with the API key
+        instead of creating a client instance.
+        """
         api_key = os.environ.get("GEMINI_API_KEY")
         if not api_key:
             logger.error("GEMINI_API_KEY environment variable is not set.")
-            raise ValueError(
-                "GEMINI_API_KEY environment variable is not set. "
-                "Please set this variable in your environment."
-            )
+            raise ValueError("GEMINI_API_KEY environment variable is not set.")
         try:
             genai.configure(api_key=api_key)
-            return genai
         except Exception as e:
-            logger.error(f"Failed to initialize Gemini client: {str(e)}")
-            raise ValueError(f"Failed to initialize Gemini client: {str(e)}")
+            logger.error(f"Failed to configure Gemini client: {str(e)}")
+            raise ValueError(f"Failed to configure Gemini client: {str(e)}")
 
-    def _ensure_chat_session(self):
-        if not self.chat_session:
-            model = self.client.GenerativeModel(
-                model_name=self.model.value,
-                generation_config=self.generation_config
-            )
-            history = []
-            for msg in self.messages:
-                history.append({"role": msg.role.value, "parts": [msg.content]})
-            self.chat_session = model.start_chat(history=history)
+    def _get_generation_config(self) -> Dict[str, Any]:
+        """
+        CHANGED: Builds the generation config as a dictionary.
+        'thinking_config' is not available in the old library.
+        'system_instruction' is passed during model initialization.
+        """
+        # Basic configuration, you can expand this with temperature, top_p, etc.
+        # from self.llm_config if needed.
+        config = {
+            "response_mime_type": "text/plain",
+            # Example: "temperature": self.llm_config.temperature
+        }
+        return config
 
-    async def _send_user_message_to_llm(self, user_message: str, image_urls: Optional[List[str]] = None, **kwargs) -> CompleteResponse:
+    async def _send_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> CompleteResponse:
         self.add_user_message(user_message)
+
         try:
-            self._ensure_chat_session()
-            response = self.chat_session.send_message(user_message)
+            history = _format_gemini_history(self.messages)
+            generation_config = self._get_generation_config()
+
+            # CHANGED: API call now uses the model instance directly.
+            response = await self.model.generate_content_async(
+                contents=history,
+                generation_config=generation_config,
+            )
+
             assistant_message = response.text
             self.add_assistant_message(assistant_message)
 
+            # CHANGED: Token usage is extracted from 'usage_metadata'.
             token_usage = TokenUsage(
-                prompt_tokens=0,
-                completion_tokens=0,
-                total_tokens=0
+                prompt_tokens=response.usage_metadata.prompt_token_count,
+                completion_tokens=response.usage_metadata.candidates_token_count,
+                total_tokens=response.usage_metadata.total_token_count
            )
 
             return CompleteResponse(
@@ -80,6 +105,47 @@ class GeminiLLM(BaseLLM):
             logger.error(f"Error in Gemini API call: {str(e)}")
             raise ValueError(f"Error in Gemini API call: {str(e)}")
 
+    async def _stream_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> AsyncGenerator[ChunkResponse, None]:
+        self.add_user_message(user_message)
+        complete_response = ""
+
+        try:
+            history = _format_gemini_history(self.messages)
+            generation_config = self._get_generation_config()
+
+            # CHANGED: API call for streaming is now part of generate_content_async.
+            response_stream = await self.model.generate_content_async(
+                contents=history,
+                generation_config=generation_config,
+                stream=True
+            )
+
+            async for chunk in response_stream:
+                chunk_text = chunk.text
+                complete_response += chunk_text
+                yield ChunkResponse(
+                    content=chunk_text,
+                    is_complete=False
+                )
+
+            self.add_assistant_message(complete_response)
+
+            # NOTE: The old library's async stream does not easily expose token usage.
+            # Keeping it at 0, consistent with your original implementation.
+            token_usage = TokenUsage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0
+            )
+
+            yield ChunkResponse(
+                content="",
+                is_complete=True,
+                usage=token_usage
+            )
+        except Exception as e:
+            logger.error(f"Error in Gemini API streaming call: {str(e)}")
+            raise ValueError(f"Error in Gemini API streaming call: {str(e)}")
+
     async def cleanup(self):
-        self.chat_session = None
-        super().cleanup()
+        await super().cleanup()

autobyteus/llm/api/groq_llm.py

@@ -7,6 +7,7 @@ from autobyteus.llm.utils.llm_config import LLMConfig
 from autobyteus.llm.utils.messages import MessageRole, Message
 from autobyteus.llm.utils.token_usage import TokenUsage
 from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse
+from autobyteus.llm.user_message import LLMUserMessage
 
 logger = logging.getLogger(__name__)
 
@@ -36,7 +37,7 @@ class GroqLLM(BaseLLM):
         except Exception as e:
             raise ValueError(f"Failed to initialize Groq client: {str(e)}")
 
-    async def _send_user_message_to_llm(self, user_message: str, image_urls: Optional[List[str]] = None, **kwargs) -> CompleteResponse:
+    async def _send_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> CompleteResponse:
         self.add_user_message(user_message)
         try:
             # Placeholder for sending message to Groq API
@@ -58,7 +59,7 @@
             raise ValueError(f"Error in Groq API call: {str(e)}")
 
     async def _stream_user_message_to_llm(
-        self, user_message: str, image_urls: Optional[List[str]] = None, **kwargs
+        self, user_message: LLMUserMessage, **kwargs
     ) -> AsyncGenerator[ChunkResponse, None]:
         self.add_user_message(user_message)
         complete_response = ""
@@ -90,4 +91,4 @@
             raise ValueError(f"Error in Groq API streaming: {str(e)}")
 
     async def cleanup(self):
-        super().cleanup()
+        await super().cleanup()

autobyteus/llm/api/mistral_llm.py

@@ -1,45 +1,91 @@
-from typing import Dict, Optional, List, AsyncGenerator
+from typing import Dict, Optional, List, Any, AsyncGenerator, Union
 import os
 import logging
+import httpx
+import asyncio
 from autobyteus.llm.models import LLMModel
 from autobyteus.llm.base_llm import BaseLLM
 from mistralai import Mistral
-from autobyteus.llm.utils.messages import MessageRole, Message
+from autobyteus.llm.utils.messages import Message, MessageRole
 from autobyteus.llm.utils.llm_config import LLMConfig
 from autobyteus.llm.utils.token_usage import TokenUsage
 from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse
+from autobyteus.llm.user_message import LLMUserMessage
+from autobyteus.llm.utils.media_payload_formatter import image_source_to_base64, get_mime_type, is_valid_image_path
 
-# Configure logger
 logger = logging.getLogger(__name__)
 
+async def _format_mistral_messages(messages: List[Message]) -> List[Dict[str, Any]]:
+    """Formats a list of internal Message objects into a list of dictionaries for the Mistral API."""
+    mistral_messages = []
+    for msg in messages:
+        # Skip empty messages from non-system roles as Mistral API may reject them
+        if not msg.content and not msg.image_urls and msg.role != MessageRole.SYSTEM:
+            continue
+
+        content: Union[str, List[Dict[str, Any]]]
+
+        if msg.image_urls:
+            content_parts: List[Dict[str, Any]] = []
+            if msg.content:
+                content_parts.append({"type": "text", "text": msg.content})
+
+            image_tasks = [image_source_to_base64(url) for url in msg.image_urls]
+            try:
+                base64_images = await asyncio.gather(*image_tasks)
+                for i, b64_image in enumerate(base64_images):
+                    original_url = msg.image_urls[i]
+                    mime_type = get_mime_type(original_url) if is_valid_image_path(original_url) else "image/jpeg"
+                    data_uri = f"data:{mime_type};base64,{b64_image}"
+
+                    # Mistral's format for image parts
+                    content_parts.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": data_uri
+                        }
+                    })
+            except Exception as e:
+                logger.error(f"Error processing images for Mistral: {e}")
+
+            if msg.audio_urls:
+                logger.warning("MistralLLM does not yet support audio; skipping.")
+            if msg.video_urls:
+                logger.warning("MistralLLM does not yet support video; skipping.")
+
+            content = content_parts
+        else:
+            content = msg.content or ""
+
+        mistral_messages.append({"role": msg.role.value, "content": content})
+
+    return mistral_messages
+
+
 class MistralLLM(BaseLLM):
     def __init__(self, model: LLMModel = None, llm_config: LLMConfig = None):
-        # Provide defaults if not specified
         if model is None:
-            model = LLMModel.mistral_large
+            model = LLMModel['mistral-large']
         if llm_config is None:
             llm_config = LLMConfig()
 
         super().__init__(model=model, llm_config=llm_config)
-        self.client = self.initialize()
+        self.http_client = httpx.AsyncClient()
+        self.client: Mistral = self._initialize()
         logger.info(f"MistralLLM initialized with model: {self.model}")
 
-    @classmethod
-    def initialize(cls):
+    def _initialize(self) -> Mistral:
        mistral_api_key = os.environ.get("MISTRAL_API_KEY")
         if not mistral_api_key:
             logger.error("MISTRAL_API_KEY environment variable is not set")
-            raise ValueError(
-                "MISTRAL_API_KEY environment variable is not set. "
-                "Please set this variable in your environment."
-            )
+            raise ValueError("MISTRAL_API_KEY environment variable is not set.")
         try:
-            return Mistral(api_key=mistral_api_key)
+            return Mistral(api_key=mistral_api_key, client=self.http_client)
         except Exception as e:
             logger.error(f"Failed to initialize Mistral client: {str(e)}")
             raise ValueError(f"Failed to initialize Mistral client: {str(e)}")
 
-    def _create_token_usage(self, usage_data: Dict) -> TokenUsage:
+    def _create_token_usage(self, usage_data: Any) -> TokenUsage:
         """Convert Mistral usage data to TokenUsage format."""
         return TokenUsage(
             prompt_tokens=usage_data.prompt_tokens,
@@ -48,26 +94,26 @@ class MistralLLM(BaseLLM):
         )
 
     async def _send_user_message_to_llm(
-        self, user_message: str, image_urls: Optional[List[str]] = None, **kwargs
+        self, user_message: LLMUserMessage, **kwargs
     ) -> CompleteResponse:
         self.add_user_message(user_message)
-
+
         try:
-            mistral_messages = [msg.to_mistral_message() for msg in self.messages]
+            mistral_messages = await _format_mistral_messages(self.messages)
 
-            chat_response = self.client.chat.complete(
+            chat_response = await self.client.chat.complete_async(
                 model=self.model.value,
                 messages=mistral_messages,
+                temperature=self.config.temperature,
+                max_tokens=self.config.max_tokens,
+                top_p=self.config.top_p,
             )
 
-            assistant_message = chat_response.choices.message.content
+            assistant_message = chat_response.choices[0].message.content
             self.add_assistant_message(assistant_message)
 
-            # Create token usage if available
-            token_usage = None
-            if hasattr(chat_response, 'usage') and chat_response.usage:
-                token_usage = self._create_token_usage(chat_response.usage)
-                logger.debug(f"Token usage recorded: {token_usage}")
+            token_usage = self._create_token_usage(chat_response.usage)
+            logger.debug(f"Token usage recorded: {token_usage}")
 
             return CompleteResponse(
                 content=assistant_message,
@@ -78,48 +124,48 @@ class MistralLLM(BaseLLM):
             raise ValueError(f"Error in Mistral API call: {str(e)}")
 
     async def _stream_user_message_to_llm(
-        self, user_message: str, image_urls: Optional[List[str]] = None, **kwargs
+        self, user_message: LLMUserMessage, **kwargs
    ) -> AsyncGenerator[ChunkResponse, None]:
         self.add_user_message(user_message)
 
+        accumulated_message = ""
+        final_usage = None
+
         try:
-            mistral_messages = [msg.to_mistral_message() for msg in self.messages]
-
-            stream = await self.client.chat.stream_async(
+            mistral_messages = await _format_mistral_messages(self.messages)
+
+            stream = self.client.chat.stream_async(
                 model=self.model.value,
                 messages=mistral_messages,
+                temperature=self.config.temperature,
+                max_tokens=self.config.max_tokens,
+                top_p=self.config.top_p,
             )
 
-            accumulated_message = ""
-
             async for chunk in stream:
-                if chunk.data.choices.delta.content is not None:
-                    token = chunk.data.choices.delta.content
+                if chunk.choices and chunk.choices[0].delta.content is not None:
+                    token = chunk.choices[0].delta.content
                     accumulated_message += token
 
-                    # For intermediate chunks, yield without usage
-                    yield ChunkResponse(
-                        content=token,
-                        is_complete=False
-                    )
-
-                # Check if this is the last chunk with usage data
-                if hasattr(chunk.data, 'usage') and chunk.data.usage is not None:
-                    token_usage = self._create_token_usage(chunk.data.usage)
-                    yield ChunkResponse(
-                        content="",
-                        is_complete=True,
-                        usage=token_usage
-                    )
-
-            # After streaming is complete, store the full message
+                    yield ChunkResponse(content=token, is_complete=False)
+
+                if hasattr(chunk, 'usage') and chunk.usage:
+                    final_usage = self._create_token_usage(chunk.usage)
+
+            # Yield the final chunk with usage data
+            yield ChunkResponse(
+                content="",
+                is_complete=True,
+                usage=final_usage
+            )
+
             self.add_assistant_message(accumulated_message)
         except Exception as e:
             logger.error(f"Error in Mistral API streaming call: {str(e)}")
             raise ValueError(f"Error in Mistral API streaming call: {str(e)}")
 
     async def cleanup(self):
-        # Clean up any resources if needed
         logger.debug("Cleaning up MistralLLM instance")
-        self.messages = []
-        super().cleanup()
+        if self.http_client and not self.http_client.is_closed:
+            await self.http_client.aclose()
+        await super().cleanup()
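
For reference, the new _format_mistral_messages helper above turns a user Message that carries text plus one image into a chat payload entry of roughly the following shape; this is a sketch derived from the diff, and the base64 string is a placeholder rather than real image data.

# Approximate shape of one entry returned by _format_mistral_messages for a
# user message with text and a single image (base64 content is a placeholder).
formatted_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {
            "type": "image_url",
            "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."},
        },
    ],
}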

autobyteus/llm/api/nvidia_llm.py

@@ -8,6 +8,7 @@ from autobyteus.llm.utils.llm_config import LLMConfig
 from autobyteus.llm.utils.messages import MessageRole, Message
 from autobyteus.llm.utils.token_usage import TokenUsage
 from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse
+from autobyteus.llm.user_message import LLMUserMessage
 
 logger = logging.getLogger(__name__)
 
@@ -38,11 +39,11 @@ class NvidiaLLM(BaseLLM):
         except Exception as e:
             raise ValueError(f"Failed to initialize Nvidia client: {str(e)}")
 
-    async def _send_user_message_to_llm(self, user_message: str, image_urls: Optional[List[str]] = None, **kwargs) -> CompleteResponse:
+    async def _send_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> CompleteResponse:
         self.add_user_message(user_message)
         try:
             completion = self.client.chat.completions.create(
-                model=self.model,
+                model=self.model.value,
                 messages=[msg.to_dict() for msg in self.messages],
                 temperature=0,
                 top_p=1,
@@ -65,12 +66,12 @@
         except Exception as e:
             raise ValueError(f"Error in Nvidia API call: {str(e)}")
 
-    async def stream_response(self, user_message: str) -> AsyncGenerator[ChunkResponse, None]:
+    async def _stream_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> AsyncGenerator[ChunkResponse, None]:
         self.add_user_message(user_message)
         complete_response = ""
         try:
             completion = self.client.chat.completions.create(
-                model=self.model,
+                model=self.model.value,
                 messages=[msg.to_dict() for msg in self.messages],
                 temperature=0,
                 top_p=1,
@@ -104,4 +105,4 @@
             raise ValueError(f"Error in Nvidia API streaming call: {str(e)}")
 
     async def cleanup(self):
-        super().cleanup()
+        await super().cleanup()