autobyteus 1.1.5__py3-none-any.whl → 1.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. autobyteus/agent/context/agent_config.py +6 -1
  2. autobyteus/agent/handlers/llm_user_message_ready_event_handler.py +30 -7
  3. autobyteus/agent/handlers/user_input_message_event_handler.py +22 -25
  4. autobyteus/agent/message/__init__.py +7 -5
  5. autobyteus/agent/message/agent_input_user_message.py +6 -16
  6. autobyteus/agent/message/context_file.py +24 -24
  7. autobyteus/agent/message/context_file_type.py +29 -8
  8. autobyteus/agent/message/multimodal_message_builder.py +47 -0
  9. autobyteus/agent/streaming/stream_event_payloads.py +23 -4
  10. autobyteus/agent/system_prompt_processor/tool_manifest_injector_processor.py +6 -2
  11. autobyteus/agent/tool_invocation.py +2 -1
  12. autobyteus/agent_team/bootstrap_steps/agent_configuration_preparation_step.py +9 -2
  13. autobyteus/agent_team/context/agent_team_config.py +1 -0
  14. autobyteus/llm/api/autobyteus_llm.py +33 -33
  15. autobyteus/llm/api/bedrock_llm.py +13 -5
  16. autobyteus/llm/api/claude_llm.py +13 -27
  17. autobyteus/llm/api/gemini_llm.py +108 -42
  18. autobyteus/llm/api/groq_llm.py +4 -3
  19. autobyteus/llm/api/mistral_llm.py +97 -51
  20. autobyteus/llm/api/nvidia_llm.py +6 -5
  21. autobyteus/llm/api/ollama_llm.py +37 -12
  22. autobyteus/llm/api/openai_compatible_llm.py +91 -91
  23. autobyteus/llm/autobyteus_provider.py +1 -1
  24. autobyteus/llm/base_llm.py +42 -139
  25. autobyteus/llm/extensions/base_extension.py +6 -6
  26. autobyteus/llm/extensions/token_usage_tracking_extension.py +3 -2
  27. autobyteus/llm/llm_factory.py +106 -4
  28. autobyteus/llm/token_counter/token_counter_factory.py +1 -1
  29. autobyteus/llm/user_message.py +43 -35
  30. autobyteus/llm/utils/llm_config.py +34 -18
  31. autobyteus/llm/utils/media_payload_formatter.py +99 -0
  32. autobyteus/llm/utils/messages.py +32 -25
  33. autobyteus/llm/utils/response_types.py +9 -3
  34. autobyteus/llm/utils/token_usage.py +6 -5
  35. autobyteus/multimedia/__init__.py +31 -0
  36. autobyteus/multimedia/audio/__init__.py +11 -0
  37. autobyteus/multimedia/audio/api/__init__.py +4 -0
  38. autobyteus/multimedia/audio/api/autobyteus_audio_client.py +59 -0
  39. autobyteus/multimedia/audio/api/gemini_audio_client.py +219 -0
  40. autobyteus/multimedia/audio/audio_client_factory.py +120 -0
  41. autobyteus/multimedia/audio/audio_model.py +96 -0
  42. autobyteus/multimedia/audio/autobyteus_audio_provider.py +108 -0
  43. autobyteus/multimedia/audio/base_audio_client.py +40 -0
  44. autobyteus/multimedia/image/__init__.py +11 -0
  45. autobyteus/multimedia/image/api/__init__.py +9 -0
  46. autobyteus/multimedia/image/api/autobyteus_image_client.py +97 -0
  47. autobyteus/multimedia/image/api/gemini_image_client.py +188 -0
  48. autobyteus/multimedia/image/api/openai_image_client.py +142 -0
  49. autobyteus/multimedia/image/autobyteus_image_provider.py +109 -0
  50. autobyteus/multimedia/image/base_image_client.py +67 -0
  51. autobyteus/multimedia/image/image_client_factory.py +118 -0
  52. autobyteus/multimedia/image/image_model.py +96 -0
  53. autobyteus/multimedia/providers.py +5 -0
  54. autobyteus/multimedia/runtimes.py +8 -0
  55. autobyteus/multimedia/utils/__init__.py +10 -0
  56. autobyteus/multimedia/utils/api_utils.py +19 -0
  57. autobyteus/multimedia/utils/multimedia_config.py +29 -0
  58. autobyteus/multimedia/utils/response_types.py +13 -0
  59. autobyteus/tools/__init__.py +3 -0
  60. autobyteus/tools/multimedia/__init__.py +8 -0
  61. autobyteus/tools/multimedia/audio_tools.py +116 -0
  62. autobyteus/tools/multimedia/image_tools.py +186 -0
  63. autobyteus/tools/tool_category.py +1 -0
  64. autobyteus/tools/usage/parsers/provider_aware_tool_usage_parser.py +5 -2
  65. autobyteus/tools/usage/providers/tool_manifest_provider.py +5 -3
  66. autobyteus/tools/usage/registries/tool_formatting_registry.py +9 -2
  67. autobyteus/tools/usage/registries/tool_usage_parser_registry.py +9 -2
  68. {autobyteus-1.1.5.dist-info → autobyteus-1.1.6.dist-info}/METADATA +9 -9
  69. {autobyteus-1.1.5.dist-info → autobyteus-1.1.6.dist-info}/RECORD +73 -45
  70. examples/run_browser_agent.py +1 -1
  71. autobyteus/llm/utils/image_payload_formatter.py +0 -89
  72. {autobyteus-1.1.5.dist-info → autobyteus-1.1.6.dist-info}/WHEEL +0 -0
  73. {autobyteus-1.1.5.dist-info → autobyteus-1.1.6.dist-info}/licenses/LICENSE +0 -0
  74. {autobyteus-1.1.5.dist-info → autobyteus-1.1.6.dist-info}/top_level.txt +0 -0
autobyteus/agent_team/bootstrap_steps/agent_configuration_preparation_step.py
@@ -15,8 +15,8 @@ logger = logging.getLogger(__name__)
 class AgentConfigurationPreparationStep(BaseAgentTeamBootstrapStep):
     """
     Bootstrap step to prepare the final, immutable configuration for every
-    agent in the team. It injects team-specific context and applies the final
-    coordinator prompt. It no longer injects tools.
+    agent in the team. It injects team-specific context, applies team-level
+    settings like tool format overrides, and prepares the final coordinator prompt.
     """
     async def execute(self, context: 'AgentTeamContext', phase_manager: 'AgentTeamPhaseManager') -> bool:
         team_id = context.team_id
@@ -44,6 +44,13 @@ class AgentConfigurationPreparationStep(BaseAgentTeamBootstrapStep):
 
             final_config = node_definition.copy()
 
+            # --- Team-level Setting Propagation ---
+            # If the team config specifies a tool format, it overrides any agent-level setting.
+            if context.config.use_xml_tool_format is not None:
+                final_config.use_xml_tool_format = context.config.use_xml_tool_format
+                logger.debug(f"Team '{team_id}': Applied team-level use_xml_tool_format={final_config.use_xml_tool_format} to agent '{unique_name}'.")
+
+
             # --- Shared Context Injection ---
             # The shared context is injected into the initial_custom_data dictionary,
             # which is then used by the AgentFactory to create the AgentRuntimeState.
autobyteus/agent_team/context/agent_team_config.py
@@ -20,6 +20,7 @@ class AgentTeamConfig:
     coordinator_node: TeamNodeConfig
     role: Optional[str] = None
     task_notification_mode: TaskNotificationMode = TaskNotificationMode.AGENT_MANUAL_NOTIFICATION
+    use_xml_tool_format: Optional[bool] = None
 
     def __post_init__(self):
         if not self.name or not isinstance(self.name, str):
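
Note: a minimal usage sketch of the new team-level flag (values are hypothetical; AgentTeamConfig may require additional fields not visible in this hunk, and the TeamNodeConfig instance is assumed to be built elsewhere):

    from autobyteus.agent_team.context.agent_team_config import AgentTeamConfig

    # coordinator is an existing TeamNodeConfig instance (construction not shown here).
    team_config = AgentTeamConfig(
        name="research_team",
        coordinator_node=coordinator,
        use_xml_tool_format=True,  # None (the default) leaves each agent's own setting untouched
    )
    # During bootstrap, AgentConfigurationPreparationStep (see the hunk above) copies a
    # non-None team value onto every agent's final config, overriding agent-level settings.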
autobyteus/llm/api/autobyteus_llm.py
@@ -4,6 +4,7 @@ from autobyteus.llm.models import LLMModel
 from autobyteus.llm.utils.llm_config import LLMConfig
 from autobyteus.llm.utils.token_usage import TokenUsage
 from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse
+from autobyteus.llm.user_message import LLMUserMessage
 from autobyteus_llm_client.client import AutobyteusClient
 import logging
 import uuid
@@ -12,36 +13,35 @@ logger = logging.getLogger(__name__)
 
 class AutobyteusLLM(BaseLLM):
     def __init__(self, model: LLMModel, llm_config: LLMConfig):
-        # The host URL is now passed via the model object.
         if not model.host_url:
             raise ValueError("AutobyteusLLM requires a host_url to be set in its LLMModel object.")
 
         super().__init__(model=model, llm_config=llm_config)
 
-        # Instantiate the client with the specific host for this model.
         self.client = AutobyteusClient(server_url=self.model.host_url)
         self.conversation_id = str(uuid.uuid4())
         logger.info(f"AutobyteusLLM initialized for model '{self.model.model_identifier}' with conversation ID: {self.conversation_id}")
 
     async def _send_user_message_to_llm(
         self,
-        user_message: str,
-        image_urls: Optional[List[str]] = None,
+        user_message: LLMUserMessage,
         **kwargs
     ) -> CompleteResponse:
         self.add_user_message(user_message)
         try:
             response = await self.client.send_message(
                 conversation_id=self.conversation_id,
-                model_name=self.model.name, # Use `name` as it's the original model name for the API
-                user_message=user_message,
-                image_urls=image_urls
+                model_name=self.model.name,
+                user_message=user_message.content,
+                image_urls=user_message.image_urls,
+                audio_urls=user_message.audio_urls,
+                video_urls=user_message.video_urls
             )
 
             assistant_message = response['response']
             self.add_assistant_message(assistant_message)
 
-            token_usage_data = response.get('token_usage', {})
+            token_usage_data = response.get('token_usage') or {}
             token_usage = TokenUsage(
                 prompt_tokens=token_usage_data.get('prompt_tokens', 0),
                 completion_tokens=token_usage_data.get('completion_tokens', 0),
@@ -59,8 +59,7 @@ class AutobyteusLLM(BaseLLM):
 
     async def _stream_user_message_to_llm(
         self,
-        user_message: str,
-        image_urls: Optional[List[str]] = None,
+        user_message: LLMUserMessage,
         **kwargs
     ) -> AsyncGenerator[ChunkResponse, None]:
         self.add_user_message(user_message)
@@ -69,36 +68,38 @@
         try:
             async for chunk in self.client.stream_message(
                 conversation_id=self.conversation_id,
-                model_name=self.model.name, # Use `name` for the API call
-                user_message=user_message,
-                image_urls=image_urls
+                model_name=self.model.name,
+                user_message=user_message.content,
+                image_urls=user_message.image_urls,
+                audio_urls=user_message.audio_urls,
+                video_urls=user_message.video_urls
             ):
                 if 'error' in chunk:
                     raise RuntimeError(chunk['error'])
 
                 content = chunk.get('content', '')
-                complete_response += content
+                if content:
+                    complete_response += content
+
                 is_complete = chunk.get('is_complete', False)
-
-                # If this is the final chunk, include token usage
+                token_usage = None
                 if is_complete:
-                    token_usage = None
-                    if chunk.get('token_usage'):
-                        token_usage = TokenUsage(
-                            prompt_tokens=chunk['token_usage'].get('prompt_tokens', 0),
-                            completion_tokens=chunk['token_usage'].get('completion_tokens', 0),
-                            total_tokens=chunk['token_usage'].get('total_tokens', 0)
-                        )
-                    yield ChunkResponse(
-                        content=content,
-                        is_complete=True,
-                        usage=token_usage
-                    )
-                else:
-                    yield ChunkResponse(
-                        content=content,
-                        is_complete=False
+                    token_usage_data = chunk.get('token_usage') or {}
+                    token_usage = TokenUsage(
+                        prompt_tokens=token_usage_data.get('prompt_tokens', 0),
+                        completion_tokens=token_usage_data.get('completion_tokens', 0),
+                        total_tokens=token_usage_data.get('total_tokens', 0)
                     )
+
+                yield ChunkResponse(
+                    content=content,
+                    reasoning=chunk.get('reasoning'),
+                    is_complete=is_complete,
+                    image_urls=chunk.get('image_urls', []),
+                    audio_urls=chunk.get('audio_urls', []),
+                    video_urls=chunk.get('video_urls', []),
+                    usage=token_usage
+                )
 
             self.add_assistant_message(complete_response)
         except Exception as e:
@@ -116,7 +117,6 @@
         await self.client.close()
 
     async def _handle_error_cleanup(self):
-        """Handle cleanup operations after errors"""
         try:
             await self.cleanup()
         except Exception as cleanup_error:
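
Note: with this change, both provider entry points take a single LLMUserMessage instead of a raw string plus image_urls. A minimal sketch of building a multimodal message (the keyword arguments are assumed from the attributes the diff reads: content, image_urls, audio_urls, video_urls; the URLs are placeholders):

    from autobyteus.llm.user_message import LLMUserMessage

    message = LLMUserMessage(
        content="Describe what you see and hear.",
        image_urls=["https://example.com/frame.png"],
        audio_urls=["https://example.com/clip.wav"],
    )
    # AutobyteusLLM then forwards each modality list separately to the server:
    # send_message(..., user_message=message.content, image_urls=message.image_urls,
    #              audio_urls=message.audio_urls, video_urls=message.video_urls)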
autobyteus/llm/api/bedrock_llm.py
@@ -9,10 +9,10 @@ from autobyteus.llm.utils.llm_config import LLMConfig
 from autobyteus.llm.utils.messages import MessageRole, Message
 from autobyteus.llm.utils.token_usage import TokenUsage
 from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse
+from autobyteus.llm.user_message import LLMUserMessage
 
 class BedrockLLM(BaseLLM):
     def __init__(self, model: LLMModel = None, llm_config: LLMConfig = None):
-        # Provide defaults if not specified
         if model is None:
             model = LLMModel.BEDROCK_CLAUDE_3_5_SONNET_API
         if llm_config is None:
@@ -43,14 +43,17 @@ class BedrockLLM(BaseLLM):
         except Exception as e:
             raise ValueError(f"Failed to initialize Bedrock client: {str(e)}")
 
-    async def _send_user_message_to_llm(self, user_message: str, image_urls: Optional[List[str]] = None, **kwargs) -> CompleteResponse:
+    async def _send_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> CompleteResponse:
         self.add_user_message(user_message)
 
+        # NOTE: This implementation does not yet support multimodal inputs for Bedrock.
+        # It will only send the text content.
+
         request_body = json.dumps({
             "anthropic_version": "bedrock-2023-05-31",
             "max_tokens": 1000,
             "temperature": 0,
-            "messages": [msg.to_dict() for msg in self.messages],
+            "messages": [msg.to_dict() for msg in self.messages if msg.role != MessageRole.SYSTEM],
             "system": self.system_message if self.system_message else ""
         })
 
@@ -79,6 +82,11 @@
                 raise ValueError(f"Bedrock API error: {error_code} - {error_message}")
         except Exception as e:
             raise ValueError(f"Error in Bedrock API call: {str(e)}")
-
+
+    async def _stream_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> AsyncGenerator[ChunkResponse, None]:
+        # Placeholder for future implementation
+        response = await self._send_user_message_to_llm(user_message, **kwargs)
+        yield ChunkResponse(content=response.content, is_complete=True, usage=response.usage)
+
     async def cleanup(self):
-        super().cleanup()
+        await super().cleanup()
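
Note on the messages filter above: the Anthropic Messages format used by Bedrock carries the system prompt in a top-level "system" field and does not accept a system role inside "messages", so system entries must be stripped from the history. A self-contained sketch of that request layout (the dicts below are simplified stand-ins for the library's Message objects):

    import json

    history = [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Hello"},
    ]
    request_body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 1000,
        "temperature": 0,
        # keep only non-system turns; the system prompt travels separately
        "messages": [m for m in history if m["role"] != "system"],
        "system": "You are terse.",
    })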
autobyteus/llm/api/claude_llm.py
@@ -8,14 +8,14 @@ from autobyteus.llm.utils.llm_config import LLMConfig
 from autobyteus.llm.utils.messages import MessageRole, Message
 from autobyteus.llm.utils.token_usage import TokenUsage
 from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse
+from autobyteus.llm.user_message import LLMUserMessage
 
 logger = logging.getLogger(__name__)
 
 class ClaudeLLM(BaseLLM):
     def __init__(self, model: LLMModel = None, llm_config: LLMConfig = None):
-        # Provide defaults if not specified
         if model is None:
-            model = LLMModel.CLAUDE_3_5_SONNET_API
+            model = LLMModel['claude-4-sonnet']
         if llm_config is None:
             llm_config = LLMConfig()
 
@@ -37,22 +37,22 @@
             raise ValueError(f"Failed to initialize Anthropic client: {str(e)}")
 
     def _get_non_system_messages(self) -> List[Dict]:
-        """
-        Returns all messages excluding system messages for Anthropic API compatibility.
-        """
+        # NOTE: This will need to be updated to handle multimodal messages for Claude
         return [msg.to_dict() for msg in self.messages if msg.role != MessageRole.SYSTEM]
 
     def _create_token_usage(self, input_tokens: int, output_tokens: int) -> TokenUsage:
-        """Convert Anthropic usage data to TokenUsage format."""
         return TokenUsage(
             prompt_tokens=input_tokens,
             completion_tokens=output_tokens,
             total_tokens=input_tokens + output_tokens
         )
 
-    async def _send_user_message_to_llm(self, user_message: str, image_urls: Optional[List[str]] = None, **kwargs) -> CompleteResponse:
+    async def _send_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> CompleteResponse:
         self.add_user_message(user_message)
 
+        # NOTE: This implementation does not yet support multimodal inputs for Claude.
+        # It will only send the text content.
+
         try:
             response = self.client.messages.create(
                 model=self.model.value,
@@ -81,12 +81,15 @@
             raise ValueError(f"Error in Claude API call: {str(e)}")
 
     async def _stream_user_message_to_llm(
-        self, user_message: str, image_urls: Optional[List[str]] = None, **kwargs
+        self, user_message: LLMUserMessage, **kwargs
    ) -> AsyncGenerator[ChunkResponse, None]:
         self.add_user_message(user_message)
         complete_response = ""
         final_message = None
 
+        # NOTE: This implementation does not yet support multimodal inputs for Claude.
+        # It will only send the text content.
+
         try:
             with self.client.messages.stream(
                 model=self.model.value,
@@ -96,30 +99,13 @@
                 messages=self._get_non_system_messages(),
             ) as stream:
                 for event in stream:
-                    logger.debug(f"Event Received: {event.type}")
-
-                    if event.type == "message_start":
-                        logger.debug(f"Message Start: {event.message}")
-
-                    elif event.type == "content_block_start":
-                        logger.debug(f"Content Block Start at index {event.index}: {event.content_block}")
-
-                    elif event.type == "content_block_delta" and event.delta.type == "text_delta":
-                        logger.debug(f"Text Delta: {event.delta.text}")
+                    if event.type == "content_block_delta" and event.delta.type == "text_delta":
                         complete_response += event.delta.text
                         yield ChunkResponse(
                             content=event.delta.text,
                             is_complete=False
                         )
 
-                    elif event.type == "message_delta":
-                        logger.debug(f"Message Delta: Stop Reason - {event.delta.stop_reason}, "
-                                     f"Stop Sequence - {event.delta.stop_sequence}")
-
-                    elif event.type == "content_block_stop":
-                        logger.debug(f"Content Block Stop at index {event.index}: {event.content_block}")
-
-                # Get final message for token usage
                 final_message = stream.get_final_message()
                 if final_message:
                     token_usage = self._create_token_usage(
@@ -140,4 +126,4 @@
             raise ValueError(f"Error in Claude API streaming: {str(e)}")
 
     async def cleanup(self):
-        super().cleanup()
+        await super().cleanup()
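
Note: the repeated `super().cleanup()` → `await super().cleanup()` changes in this release matter because the base class's cleanup is a coroutine; calling it without `await` only creates the coroutine object and never runs it. A generic, runnable illustration (not library code):

    import asyncio

    class Base:
        async def cleanup(self):
            print("base cleanup ran")

    class Child(Base):
        async def cleanup(self):
            await super().cleanup()  # without `await`, the base coroutine would never execute

    asyncio.run(Child().cleanup())  # prints "base cleanup ran"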
autobyteus/llm/api/gemini_llm.py
@@ -1,6 +1,6 @@
 import logging
-from typing import Dict, Optional, List, AsyncGenerator
-import google.generativeai as genai
+from typing import Dict, List, AsyncGenerator, Any
+import google.generativeai as genai # CHANGED: Using the older 'google.generativeai' library
 import os
 from autobyteus.llm.models import LLMModel
 from autobyteus.llm.base_llm import BaseLLM
@@ -8,68 +8,93 @@ from autobyteus.llm.utils.llm_config import LLMConfig
 from autobyteus.llm.utils.messages import MessageRole, Message
 from autobyteus.llm.utils.token_usage import TokenUsage
 from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse
+from autobyteus.llm.user_message import LLMUserMessage
 
 logger = logging.getLogger(__name__)
 
+def _format_gemini_history(messages: List[Message]) -> List[Dict[str, Any]]:
+    """
+    Formats internal message history for the Gemini API.
+    This function remains compatible with the older library.
+    """
+    history = []
+    # System message is handled separately in the model initialization
+    for msg in messages:
+        if msg.role in [MessageRole.USER, MessageRole.ASSISTANT]:
+            role = 'model' if msg.role == MessageRole.ASSISTANT else 'user'
+            history.append({"role": role, "parts": [{"text": msg.content}]})
+    return history
+
 class GeminiLLM(BaseLLM):
     def __init__(self, model: LLMModel = None, llm_config: LLMConfig = None):
-        self.generation_config = {
-            "temperature": 0,
-            "top_p": 0.95,
-            "top_k": 64,
-            "max_output_tokens": 8192,
-            "response_mime_type": "text/plain",
-        }
-
-        # Provide defaults if not specified
         if model is None:
-            model = LLMModel.GEMINI_1_5_FLASH_API
+            model = LLMModel['gemini-2.5-flash'] # Note: Ensure model name is compatible, e.g., 'gemini-1.5-flash-latest'
         if llm_config is None:
             llm_config = LLMConfig()
-
+
         super().__init__(model=model, llm_config=llm_config)
-        self.client = self.initialize()
-        self.chat_session = None
+
+        # CHANGED: Initialization flow. Configure API key and then instantiate the model.
+        self.initialize()
+
+        system_instruction = self.system_message if self.system_message else None
+
+        self.model = genai.GenerativeModel(
+            model_name=self.model.value,
+            system_instruction=system_instruction
+        )
 
-    @classmethod
-    def initialize(cls):
+    @staticmethod
+    def initialize():
+        """
+        CHANGED: This method now configures the genai library with the API key
+        instead of creating a client instance.
+        """
         api_key = os.environ.get("GEMINI_API_KEY")
         if not api_key:
             logger.error("GEMINI_API_KEY environment variable is not set.")
-            raise ValueError(
-                "GEMINI_API_KEY environment variable is not set. "
-                "Please set this variable in your environment."
-            )
+            raise ValueError("GEMINI_API_KEY environment variable is not set.")
         try:
             genai.configure(api_key=api_key)
-            return genai
         except Exception as e:
-            logger.error(f"Failed to initialize Gemini client: {str(e)}")
-            raise ValueError(f"Failed to initialize Gemini client: {str(e)}")
+            logger.error(f"Failed to configure Gemini client: {str(e)}")
+            raise ValueError(f"Failed to configure Gemini client: {str(e)}")
 
-    def _ensure_chat_session(self):
-        if not self.chat_session:
-            model = self.client.GenerativeModel(
-                model_name=self.model.value,
-                generation_config=self.generation_config
-            )
-            history = []
-            for msg in self.messages:
-                history.append({"role": msg.role.value, "parts": [msg.content]})
-            self.chat_session = model.start_chat(history=history)
+    def _get_generation_config(self) -> Dict[str, Any]:
+        """
+        CHANGED: Builds the generation config as a dictionary.
+        'thinking_config' is not available in the old library.
+        'system_instruction' is passed during model initialization.
+        """
+        # Basic configuration, you can expand this with temperature, top_p, etc.
+        # from self.llm_config if needed.
+        config = {
+            "response_mime_type": "text/plain",
+            # Example: "temperature": self.llm_config.temperature
+        }
+        return config
 
-    async def _send_user_message_to_llm(self, user_message: str, image_urls: Optional[List[str]] = None, **kwargs) -> CompleteResponse:
+    async def _send_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> CompleteResponse:
         self.add_user_message(user_message)
+
         try:
-            self._ensure_chat_session()
-            response = self.chat_session.send_message(user_message)
+            history = _format_gemini_history(self.messages)
+            generation_config = self._get_generation_config()
+
+            # CHANGED: API call now uses the model instance directly.
+            response = await self.model.generate_content_async(
+                contents=history,
+                generation_config=generation_config,
+            )
+
             assistant_message = response.text
             self.add_assistant_message(assistant_message)
 
+            # CHANGED: Token usage is extracted from 'usage_metadata'.
             token_usage = TokenUsage(
-                prompt_tokens=0,
-                completion_tokens=0,
-                total_tokens=0
+                prompt_tokens=response.usage_metadata.prompt_token_count,
+                completion_tokens=response.usage_metadata.candidates_token_count,
+                total_tokens=response.usage_metadata.total_token_count
             )
 
             return CompleteResponse(
@@ -80,6 +105,47 @@ class GeminiLLM(BaseLLM):
             logger.error(f"Error in Gemini API call: {str(e)}")
             raise ValueError(f"Error in Gemini API call: {str(e)}")
 
+    async def _stream_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> AsyncGenerator[ChunkResponse, None]:
+        self.add_user_message(user_message)
+        complete_response = ""
+
+        try:
+            history = _format_gemini_history(self.messages)
+            generation_config = self._get_generation_config()
+
+            # CHANGED: API call for streaming is now part of generate_content_async.
+            response_stream = await self.model.generate_content_async(
+                contents=history,
+                generation_config=generation_config,
+                stream=True
+            )
+
+            async for chunk in response_stream:
+                chunk_text = chunk.text
+                complete_response += chunk_text
+                yield ChunkResponse(
+                    content=chunk_text,
+                    is_complete=False
+                )
+
+            self.add_assistant_message(complete_response)
+
+            # NOTE: The old library's async stream does not easily expose token usage.
+            # Keeping it at 0, consistent with your original implementation.
+            token_usage = TokenUsage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0
+            )
+
+            yield ChunkResponse(
+                content="",
+                is_complete=True,
+                usage=token_usage
+            )
+        except Exception as e:
+            logger.error(f"Error in Gemini API streaming call: {str(e)}")
+            raise ValueError(f"Error in Gemini API streaming call: {str(e)}")
+
     async def cleanup(self):
-        self.chat_session = None
-        super().cleanup()
+        await super().cleanup()
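
Note: the rewritten GeminiLLM above targets the google.generativeai package. A standalone sketch of that call pattern, runnable outside the class (model name and prompts are illustrative; GEMINI_API_KEY must be set):

    import asyncio
    import os
    import google.generativeai as genai

    async def main():
        genai.configure(api_key=os.environ["GEMINI_API_KEY"])
        model = genai.GenerativeModel(
            model_name="gemini-1.5-flash",
            system_instruction="You are a concise assistant.",
        )
        # Non-streaming call: usage metadata is available on the response.
        response = await model.generate_content_async(
            contents=[{"role": "user", "parts": [{"text": "Say hello."}]}],
            generation_config={"response_mime_type": "text/plain"},
        )
        print(response.text, response.usage_metadata.total_token_count)

        # Streaming call: iterate chunks; per-chunk token usage is not exposed here.
        stream = await model.generate_content_async(
            contents=[{"role": "user", "parts": [{"text": "Count to three."}]}],
            stream=True,
        )
        async for chunk in stream:
            print(chunk.text, end="")

    asyncio.run(main())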
autobyteus/llm/api/groq_llm.py
@@ -7,6 +7,7 @@ from autobyteus.llm.utils.llm_config import LLMConfig
 from autobyteus.llm.utils.messages import MessageRole, Message
 from autobyteus.llm.utils.token_usage import TokenUsage
 from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse
+from autobyteus.llm.user_message import LLMUserMessage
 
 logger = logging.getLogger(__name__)
 
@@ -36,7 +37,7 @@
         except Exception as e:
             raise ValueError(f"Failed to initialize Groq client: {str(e)}")
 
-    async def _send_user_message_to_llm(self, user_message: str, image_urls: Optional[List[str]] = None, **kwargs) -> CompleteResponse:
+    async def _send_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> CompleteResponse:
         self.add_user_message(user_message)
         try:
             # Placeholder for sending message to Groq API
@@ -58,7 +59,7 @@
             raise ValueError(f"Error in Groq API call: {str(e)}")
 
     async def _stream_user_message_to_llm(
-        self, user_message: str, image_urls: Optional[List[str]] = None, **kwargs
+        self, user_message: LLMUserMessage, **kwargs
     ) -> AsyncGenerator[ChunkResponse, None]:
         self.add_user_message(user_message)
         complete_response = ""
@@ -90,4 +91,4 @@
             raise ValueError(f"Error in Groq API streaming: {str(e)}")
 
     async def cleanup(self):
-        super().cleanup()
+        await super().cleanup()
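
Note: taken together, these hunks converge every provider on the same pair of entry points. A hedged sketch of a minimal custom backend following that interface (the EchoLLM class is hypothetical, and response constructor arguments beyond content, is_complete, and usage are assumptions):

    from typing import AsyncGenerator
    from autobyteus.llm.base_llm import BaseLLM
    from autobyteus.llm.user_message import LLMUserMessage
    from autobyteus.llm.utils.response_types import CompleteResponse, ChunkResponse

    class EchoLLM(BaseLLM):
        async def _send_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> CompleteResponse:
            self.add_user_message(user_message)
            self.add_assistant_message(user_message.content)
            return CompleteResponse(content=user_message.content)

        async def _stream_user_message_to_llm(self, user_message: LLMUserMessage, **kwargs) -> AsyncGenerator[ChunkResponse, None]:
            self.add_user_message(user_message)
            yield ChunkResponse(content=user_message.content, is_complete=True)
            self.add_assistant_message(user_message.content)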