autobyteus 1.1.5__py3-none-any.whl → 1.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. autobyteus/agent/context/agent_config.py +6 -1
  2. autobyteus/agent/handlers/llm_user_message_ready_event_handler.py +30 -7
  3. autobyteus/agent/handlers/user_input_message_event_handler.py +22 -25
  4. autobyteus/agent/message/__init__.py +7 -5
  5. autobyteus/agent/message/agent_input_user_message.py +6 -16
  6. autobyteus/agent/message/context_file.py +24 -24
  7. autobyteus/agent/message/context_file_type.py +29 -8
  8. autobyteus/agent/message/multimodal_message_builder.py +47 -0
  9. autobyteus/agent/streaming/stream_event_payloads.py +23 -4
  10. autobyteus/agent/system_prompt_processor/tool_manifest_injector_processor.py +6 -2
  11. autobyteus/agent/tool_invocation.py +2 -1
  12. autobyteus/agent_team/bootstrap_steps/agent_configuration_preparation_step.py +9 -2
  13. autobyteus/agent_team/context/agent_team_config.py +1 -0
  14. autobyteus/llm/api/autobyteus_llm.py +33 -33
  15. autobyteus/llm/api/bedrock_llm.py +13 -5
  16. autobyteus/llm/api/claude_llm.py +13 -27
  17. autobyteus/llm/api/gemini_llm.py +108 -42
  18. autobyteus/llm/api/groq_llm.py +4 -3
  19. autobyteus/llm/api/mistral_llm.py +97 -51
  20. autobyteus/llm/api/nvidia_llm.py +6 -5
  21. autobyteus/llm/api/ollama_llm.py +37 -12
  22. autobyteus/llm/api/openai_compatible_llm.py +91 -91
  23. autobyteus/llm/autobyteus_provider.py +1 -1
  24. autobyteus/llm/base_llm.py +42 -139
  25. autobyteus/llm/extensions/base_extension.py +6 -6
  26. autobyteus/llm/extensions/token_usage_tracking_extension.py +3 -2
  27. autobyteus/llm/llm_factory.py +106 -4
  28. autobyteus/llm/token_counter/token_counter_factory.py +1 -1
  29. autobyteus/llm/user_message.py +43 -35
  30. autobyteus/llm/utils/llm_config.py +34 -18
  31. autobyteus/llm/utils/media_payload_formatter.py +99 -0
  32. autobyteus/llm/utils/messages.py +32 -25
  33. autobyteus/llm/utils/response_types.py +9 -3
  34. autobyteus/llm/utils/token_usage.py +6 -5
  35. autobyteus/multimedia/__init__.py +31 -0
  36. autobyteus/multimedia/audio/__init__.py +11 -0
  37. autobyteus/multimedia/audio/api/__init__.py +4 -0
  38. autobyteus/multimedia/audio/api/autobyteus_audio_client.py +59 -0
  39. autobyteus/multimedia/audio/api/gemini_audio_client.py +219 -0
  40. autobyteus/multimedia/audio/audio_client_factory.py +120 -0
  41. autobyteus/multimedia/audio/audio_model.py +96 -0
  42. autobyteus/multimedia/audio/autobyteus_audio_provider.py +108 -0
  43. autobyteus/multimedia/audio/base_audio_client.py +40 -0
  44. autobyteus/multimedia/image/__init__.py +11 -0
  45. autobyteus/multimedia/image/api/__init__.py +9 -0
  46. autobyteus/multimedia/image/api/autobyteus_image_client.py +97 -0
  47. autobyteus/multimedia/image/api/gemini_image_client.py +188 -0
  48. autobyteus/multimedia/image/api/openai_image_client.py +142 -0
  49. autobyteus/multimedia/image/autobyteus_image_provider.py +109 -0
  50. autobyteus/multimedia/image/base_image_client.py +67 -0
  51. autobyteus/multimedia/image/image_client_factory.py +118 -0
  52. autobyteus/multimedia/image/image_model.py +96 -0
  53. autobyteus/multimedia/providers.py +5 -0
  54. autobyteus/multimedia/runtimes.py +8 -0
  55. autobyteus/multimedia/utils/__init__.py +10 -0
  56. autobyteus/multimedia/utils/api_utils.py +19 -0
  57. autobyteus/multimedia/utils/multimedia_config.py +29 -0
  58. autobyteus/multimedia/utils/response_types.py +13 -0
  59. autobyteus/tools/__init__.py +3 -0
  60. autobyteus/tools/multimedia/__init__.py +8 -0
  61. autobyteus/tools/multimedia/audio_tools.py +116 -0
  62. autobyteus/tools/multimedia/image_tools.py +186 -0
  63. autobyteus/tools/tool_category.py +1 -0
  64. autobyteus/tools/usage/parsers/provider_aware_tool_usage_parser.py +5 -2
  65. autobyteus/tools/usage/providers/tool_manifest_provider.py +5 -3
  66. autobyteus/tools/usage/registries/tool_formatting_registry.py +9 -2
  67. autobyteus/tools/usage/registries/tool_usage_parser_registry.py +9 -2
  68. {autobyteus-1.1.5.dist-info → autobyteus-1.1.6.dist-info}/METADATA +9 -9
  69. {autobyteus-1.1.5.dist-info → autobyteus-1.1.6.dist-info}/RECORD +73 -45
  70. examples/run_browser_agent.py +1 -1
  71. autobyteus/llm/utils/image_payload_formatter.py +0 -89
  72. {autobyteus-1.1.5.dist-info → autobyteus-1.1.6.dist-info}/WHEEL +0 -0
  73. {autobyteus-1.1.5.dist-info → autobyteus-1.1.6.dist-info}/licenses/LICENSE +0 -0
  74. {autobyteus-1.1.5.dist-info → autobyteus-1.1.6.dist-info}/top_level.txt +0 -0
@@ -55,6 +55,7 @@ class LLMConfig:
55
55
  frequency_penalty: Optional[float] = None
56
56
  presence_penalty: Optional[float] = None
57
57
  stop_sequences: Optional[List] = None
58
+ uses_max_completion_tokens: bool = False
58
59
  extra_params: Dict[str, Any] = field(default_factory=dict)
59
60
  pricing_config: TokenPricingConfig = field(default_factory=TokenPricingConfig)
60
61
 
@@ -102,17 +103,28 @@ class LLMConfig:
102
103
  data_copy = data.copy()
103
104
  pricing_config_data = data_copy.pop('pricing_config', {})
104
105
 
106
+ # Create a new dictionary for known fields to avoid passing them in twice
107
+ known_fields = {
108
+ 'rate_limit', 'token_limit', 'system_message', 'temperature',
109
+ 'max_tokens', 'top_p', 'frequency_penalty', 'presence_penalty',
110
+ 'stop_sequences', 'uses_max_completion_tokens', 'extra_params',
111
+ 'pricing_config'
112
+ }
113
+
114
+ init_kwargs = {k: v for k, v in data_copy.items() if k in known_fields}
115
+
105
116
  config = cls(
106
- rate_limit=data_copy.get('rate_limit'),
107
- token_limit=data_copy.get('token_limit'),
108
- system_message=data_copy.get('system_message', "You are a helpful assistant."),
109
- temperature=data_copy.get('temperature', 0.7),
110
- max_tokens=data_copy.get('max_tokens'),
111
- top_p=data_copy.get('top_p'),
112
- frequency_penalty=data_copy.get('frequency_penalty'),
113
- presence_penalty=data_copy.get('presence_penalty'),
114
- stop_sequences=data_copy.get('stop_sequences'),
115
- extra_params=data_copy.get('extra_params', {}),
117
+ rate_limit=init_kwargs.get('rate_limit'),
118
+ token_limit=init_kwargs.get('token_limit'),
119
+ system_message=init_kwargs.get('system_message', "You are a helpful assistant."),
120
+ temperature=init_kwargs.get('temperature', 0.7),
121
+ max_tokens=init_kwargs.get('max_tokens'),
122
+ top_p=init_kwargs.get('top_p'),
123
+ frequency_penalty=init_kwargs.get('frequency_penalty'),
124
+ presence_penalty=init_kwargs.get('presence_penalty'),
125
+ stop_sequences=init_kwargs.get('stop_sequences'),
126
+ uses_max_completion_tokens=init_kwargs.get('uses_max_completion_tokens', False),
127
+ extra_params=init_kwargs.get('extra_params', {}),
116
128
  pricing_config=pricing_config_data
117
129
  )
118
130
  return config
@@ -162,26 +174,30 @@ class LLMConfig:
162
174
  for f_info in fields(override_config):
163
175
  override_value = getattr(override_config, f_info.name)
164
176
 
177
+ # Special handling for booleans where we want to merge if it's not the default
178
+ # For `uses_max_completion_tokens`, the default is False, so `if override_value:` is fine
179
+ is_boolean_field = f_info.type == bool
180
+
181
+ # Standard check for None, but also merge if it's a non-default boolean
165
182
  if override_value is not None:
166
- if f_info.name == 'pricing_config':
167
- # Ensure self.pricing_config is an object (should be by __post_init__)
183
+ # For uses_max_completion_tokens, `False` is a valid override value, but `None` is not
184
+ if is_boolean_field and override_value is False and getattr(self, f_info.name) is True:
185
+ setattr(self, f_info.name, override_value)
186
+ elif f_info.name == 'pricing_config':
168
187
  if not isinstance(self.pricing_config, TokenPricingConfig):
169
- self.pricing_config = TokenPricingConfig() # Should not be needed
188
+ self.pricing_config = TokenPricingConfig()
170
189
 
171
- # override_value here is override_config.pricing_config, which is TokenPricingConfig
172
190
  if isinstance(override_value, TokenPricingConfig):
173
191
  self.pricing_config.merge_with(override_value)
174
- elif isinstance(override_value, dict): # Should not happen if override_config is LLMConfig
192
+ elif isinstance(override_value, dict):
175
193
  self.pricing_config.merge_with(TokenPricingConfig.from_dict(override_value))
176
194
  else:
177
195
  logger.warning(f"Skipping merge for pricing_config due to unexpected override type: {type(override_value)}")
178
196
  elif f_info.name == 'extra_params':
179
- # For extra_params (dict), merge dictionaries
180
197
  if isinstance(override_value, dict) and isinstance(self.extra_params, dict):
181
198
  self.extra_params.update(override_value)
182
199
  else:
183
- setattr(self, f_info.name, override_value) # Fallback to direct set if types mismatch
200
+ setattr(self, f_info.name, override_value)
184
201
  else:
185
202
  setattr(self, f_info.name, override_value)
186
203
  logger.debug(f"LLMConfig merged. Current state after merge: rate_limit={self.rate_limit}, temp={self.temperature}, system_message='{self.system_message}'")
187
-
@@ -0,0 +1,99 @@
1
import base64
import logging
import mimetypes
import os
from pathlib import Path
from typing import Dict, Union

import httpx
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
# Shared async HTTP client used by `url_to_base64` to download remote images.
#
# TLS certificate verification stays DISABLED by default, as before, so that
# local development servers with self-signed certificates keep working.
# Set the environment variable AUTOBYTEUS_HTTP_VERIFY_SSL to "1", "true" or
# "yes" to turn verification on. This is strongly recommended whenever images
# are fetched from the public internet.
_VERIFY_SSL = os.getenv("AUTOBYTEUS_HTTP_VERIFY_SSL", "").strip().lower() in {"1", "true", "yes"}
_http_client = httpx.AsyncClient(verify=_VERIFY_SSL)

if not _VERIFY_SSL:
    # Make the insecure configuration loud so it is not shipped unnoticed.
    logger.warning(
        "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
        "SECURITY WARNING: SSL certificate verification is DISABLED for the image "
        "downloader (httpx client in media_payload_formatter.py).\n"
        "This is intended for development and testing with local servers using "
        "self-signed certificates. In a production environment, this could expose "
        "the system to Man-in-the-Middle (MitM) attacks when downloading images from "
        "the public internet.\n"
        "Set AUTOBYTEUS_HTTP_VERIFY_SSL=true to enable certificate verification.\n"
        "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
    )
25
+
26
+
27
def get_mime_type(file_path: str) -> str:
    """Guess the image MIME type of *file_path* from its extension.

    Falls back to ``'image/jpeg'`` when no type can be guessed or when the
    guessed type is not an ``image/*`` type.
    """
    guessed = mimetypes.guess_type(file_path)[0]
    is_image_type = bool(guessed) and guessed.startswith('image/')
    return guessed if is_image_type else 'image/jpeg'
33
+
34
+
35
def is_base64(s: str) -> bool:
    """Return True if *s* is a non-empty, strictly valid base64 string.

    The check is strict:
    - the length must be a multiple of 4;
    - only the standard base64 alphabet plus ``=`` padding is accepted, with
      no whitespace and no URL-safe characters;
    - non-``str`` inputs and the empty string are rejected. The empty string
      carries no data, and accepting it would let an empty image source pass
      as a payload.
    """
    if not isinstance(s, str) or not s or len(s) % 4 != 0:
        return False
    try:
        base64.b64decode(s, validate=True)
    except (ValueError, TypeError):
        # binascii.Error subclasses ValueError; non-ASCII str input also raises ValueError.
        return False
    return True
45
+
46
+
47
def is_valid_image_path(path: str) -> bool:
    """Return True if *path* names an existing file with a supported image extension.

    Supported extensions (case-insensitive): .jpg, .jpeg, .png, .gif, .webp.
    Never raises. Any input that cannot be interpreted as a path, or that the
    OS refuses to stat (e.g. a name that is too long), yields False.
    """
    valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
    try:
        file_path = Path(path)
        # Check the purely lexical suffix first so that non-path inputs
        # (such as long base64 strings) never reach the filesystem.
        if file_path.suffix.lower() not in valid_extensions:
            return False
        return file_path.is_file()
    except (TypeError, ValueError, OSError):
        # OSError covers e.g. ENAMETOOLONG. Path.is_file() does not swallow
        # it on every Python version.
        return False
55
+
56
+
57
def create_data_uri(mime_type: str, base64_data: str) -> Dict:
    """Wrap base64 image data in an ``image_url`` content-part dictionary.

    The result has the shape
    ``{"type": "image_url", "image_url": {"url": "data:<mime>;base64,<data>"}}``.
    """
    uri = f"data:{mime_type};base64,{base64_data}"
    return dict(type="image_url", image_url=dict(url=uri))
65
+
66
def file_to_base64(path: str) -> str:
    """Read the file at *path* and return its bytes base64-encoded as text.

    Any failure while reading or encoding is logged and then re-raised unchanged.
    """
    try:
        with open(path, "rb") as handle:
            encoded = base64.b64encode(handle.read())
        return encoded.decode("utf-8")
    except Exception as e:
        logger.error(f"Failed to read and encode image file at {path}: {e}")
        raise
74
+
75
async def url_to_base64(url: str) -> str:
    """Download *url* with the shared HTTP client and return the body base64-encoded.

    Transport errors, and error statuses reported by ``raise_for_status()``,
    are logged and re-raised as the original ``httpx.HTTPError``.
    """
    try:
        response = await _http_client.get(url)
        response.raise_for_status()
    except httpx.HTTPError as e:
        logger.error(f"Failed to download image from URL {url}: {e}")
        raise
    return base64.b64encode(response.content).decode("utf-8")
84
+
85
async def image_source_to_base64(image_source: str) -> str:
    """
    Convert an image source into a base64 encoded string.

    Accepted sources, checked in this order:
      1. a path to an existing local image file (see `is_valid_image_path`);
      2. an ``http://`` or ``https://`` URL, downloaded via `url_to_base64`;
      3. a ``data:<mime>;base64,<payload>`` URI, whose payload is returned;
      4. a raw base64 string, returned unchanged.

    Raises:
        ValueError: if the source matches none of the forms above.
    """
    if is_valid_image_path(image_source):
        return file_to_base64(image_source)

    if image_source.startswith(("http://", "https://")):
        return await url_to_base64(image_source)

    # Data URIs (the format built by `create_data_uri`) carry the base64
    # payload after the first comma. Previously such inputs were rejected.
    if image_source.startswith("data:"):
        header, sep, payload = image_source.partition(",")
        if sep and header.endswith(";base64") and is_base64(payload):
            return payload

    if is_base64(image_source):
        return image_source

    raise ValueError("Invalid image source: not a valid file path, URL, or base64 string.")
@@ -7,34 +7,41 @@ class MessageRole(Enum):
7
7
  ASSISTANT = "assistant"
8
8
 
9
9
class Message:
    """A conversation-history entry: a role, optional text and optional media URIs."""

    def __init__(self,
                 role: MessageRole,
                 content: Optional[str] = None,
                 reasoning_content: Optional[str] = None,
                 image_urls: Optional[List[str]] = None,
                 audio_urls: Optional[List[str]] = None,
                 video_urls: Optional[List[str]] = None):
        """
        Initializes a rich Message object for conversation history.

        Args:
            role: The role of the message originator.
            content: The textual content of the message.
            reasoning_content: Optional reasoning/thought process from an assistant.
            image_urls: Optional list of image URIs (copied; None becomes []).
            audio_urls: Optional list of audio URIs (copied; None becomes []).
            video_urls: Optional list of video URIs (copied; None becomes []).
        """
        self.role = role
        self.content = content
        self.reasoning_content = reasoning_content
        # Copy the incoming sequences so that later mutation of the caller's
        # list does not silently change this message (and vice versa).
        self.image_urls = list(image_urls) if image_urls else []
        self.audio_urls = list(audio_urls) if audio_urls else []
        self.video_urls = list(video_urls) if video_urls else []

    def to_dict(self) -> Dict[str, Union[str, List[str], None]]:
        """
        Returns a simple dictionary representation of the Message object.

        This is for internal use and does not format for any specific API.
        The URL lists are copies, so mutating the returned dict does not
        alter the message.
        """
        return {
            "role": self.role.value,
            "content": self.content,
            "reasoning_content": self.reasoning_content,
            "image_urls": list(self.image_urls),
            "audio_urls": list(self.audio_urls),
            "video_urls": list(self.video_urls),
        }
@@ -1,5 +1,5 @@
1
- from dataclasses import dataclass
2
- from typing import Optional
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional, List
3
3
  from autobyteus.llm.utils.token_usage import TokenUsage
4
4
 
5
5
  @dataclass
@@ -7,6 +7,9 @@ class CompleteResponse:
7
7
  content: str
8
8
  reasoning: Optional[str] = None
9
9
  usage: Optional[TokenUsage] = None
10
+ image_urls: List[str] = field(default_factory=list)
11
+ audio_urls: List[str] = field(default_factory=list)
12
+ video_urls: List[str] = field(default_factory=list)
10
13
 
11
14
  @classmethod
12
15
  def from_content(cls, content: str) -> 'CompleteResponse':
@@ -17,4 +20,7 @@ class ChunkResponse:
17
20
  content: str # The actual content/text of the chunk
18
21
  reasoning: Optional[str] = None
19
22
  is_complete: bool = False # Indicates if this is the final chunk
20
- usage: Optional[TokenUsage] = None # Token usage stats, typically available in final chunk
23
+ usage: Optional[TokenUsage] = None # Token usage stats, typically available in final chunk
24
+ image_urls: List[str] = field(default_factory=list)
25
+ audio_urls: List[str] = field(default_factory=list)
26
+ video_urls: List[str] = field(default_factory=list)
@@ -1,8 +1,8 @@
1
1
  # file: autobyteus/autobyteus/llm/utils/token_usage.py
2
2
  from typing import Optional
3
- from pydantic import BaseModel # MODIFIED: Import BaseModel
3
+ from pydantic import BaseModel, ConfigDict # MODIFIED: Import ConfigDict
4
4
 
5
- # MODIFIED: Change from dataclass to Pydantic BaseModel
5
+ # MODIFIED: Change from dataclass to Pydantic BaseModel and use model_config
6
6
  class TokenUsage(BaseModel):
7
7
  prompt_tokens: int
8
8
  completion_tokens: int
@@ -11,6 +11,7 @@ class TokenUsage(BaseModel):
11
11
  completion_cost: Optional[float] = None
12
12
  total_cost: Optional[float] = None
13
13
 
14
- class Config:
15
- populate_by_name = True # If you use aliases, or for general Pydantic v2 compatibility
16
- # or model_config = ConfigDict(populate_by_name=True) for Pydantic v2
14
+ # FIX: Use model_config with ConfigDict for Pydantic v2 compatibility
15
+ model_config = ConfigDict(
16
+ populate_by_name=True,
17
+ )
@@ -0,0 +1,31 @@
1
+ from .providers import MultimediaProvider
2
+ from .runtimes import MultimediaRuntime
3
+ from .utils import *
4
+ from .image import *
5
+ from .audio import *
6
+
7
+
8
# Public API of the multimedia package, re-exported from the imports above.
__all__ = [
    # Factories
    "image_client_factory", "ImageClientFactory",
    "audio_client_factory", "AudioClientFactory",
    # Models
    "ImageModel", "AudioModel",
    # Base clients
    "BaseImageClient", "BaseAudioClient",
    # Enums
    "MultimediaProvider", "MultimediaRuntime",
    # Response types and config
    "ImageGenerationResponse", "SpeechGenerationResponse", "MultimediaConfig",
]
@@ -0,0 +1,11 @@
1
+ from .audio_client_factory import audio_client_factory, AudioClientFactory
2
+ from .audio_model import AudioModel
3
+ from .base_audio_client import BaseAudioClient
4
+ from .api import *
5
+
6
# Names exported by this package via star import.
__all__ = [
    "audio_client_factory", "AudioClientFactory",
    "AudioModel",
    "BaseAudioClient",
]
@@ -0,0 +1,4 @@
1
+ from .gemini_audio_client import GeminiAudioClient
2
+ from .autobyteus_audio_client import AutobyteusAudioClient
3
+
4
# Concrete audio client implementations exported by this package.
__all__ = [
    "GeminiAudioClient",
    "AutobyteusAudioClient",
]
@@ -0,0 +1,59 @@
1
+ import logging
2
+ from typing import Optional, List, Dict, Any, TYPE_CHECKING
3
+ from autobyteus_llm_client import AutobyteusClient
4
+ from autobyteus.multimedia.audio.base_audio_client import BaseAudioClient
5
+ from autobyteus.multimedia.utils.response_types import SpeechGenerationResponse
6
+
7
+ if TYPE_CHECKING:
8
+ from autobyteus.multimedia.audio.audio_model import AudioModel
9
+ from autobyteus.multimedia.utils.multimedia_config import MultimediaConfig
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
class AutobyteusAudioClient(BaseAudioClient):
    """
    Audio client that delegates speech generation to a remote Autobyteus server.

    The server address comes from ``model.host_url``. All requests go through
    an ``AutobyteusClient`` created once, at construction time.
    """

    def __init__(self, model: "AudioModel", config: "MultimediaConfig"):
        super().__init__(model, config)
        host = model.host_url
        if not host:
            raise ValueError("AutobyteusAudioClient requires a host_url in its AudioModel.")

        self.autobyteus_client = AutobyteusClient(server_url=host)
        logger.info(f"AutobyteusAudioClient initialized for model '{model.name}' on host '{host}'.")

    async def generate_speech(
        self,
        prompt: str,
        generation_config: Optional[Dict[str, Any]] = None
    ) -> SpeechGenerationResponse:
        """
        Ask the remote server's generate_speech endpoint to synthesize *prompt*.

        Returns a SpeechGenerationResponse holding the server's "audio_urls".
        Raises ValueError when the server returns no URLs. Every error is
        logged with its traceback and then re-raised.
        """
        try:
            logger.info(f"Sending speech generation request for model '{self.model.name}' to {self.model.host_url}")

            payload = await self.autobyteus_client.generate_speech(
                model_name=self.model.name,
                prompt=prompt,
                generation_config=generation_config,
            )

            urls = payload.get("audio_urls", [])
            if urls:
                return SpeechGenerationResponse(audio_urls=urls)
            raise ValueError("Remote Autobyteus server did not return any audio URLs.")

        except Exception as e:
            logger.error(f"Error calling Autobyteus server for speech generation: {e}", exc_info=True)
            raise

    async def cleanup(self):
        """Close the underlying AutobyteusClient when one is set."""
        remote = self.autobyteus_client
        if not remote:
            return
        await remote.close()
        logger.debug("AutobyteusAudioClient cleaned up.")
@@ -0,0 +1,219 @@
1
+ import asyncio
2
+ import base64
3
+ import logging
4
+ import os
5
+ import uuid
6
+ import wave
7
+ from typing import Optional, Dict, Any, TYPE_CHECKING, List
8
+
9
+ # Old/legacy Gemini SDK (as requested)
10
+ import google.generativeai as genai
11
+
12
+ from autobyteus.multimedia.audio.base_audio_client import BaseAudioClient
13
+ from autobyteus.multimedia.utils.response_types import SpeechGenerationResponse
14
+
15
+ if TYPE_CHECKING:
16
+ from autobyteus.multimedia.audio.audio_model import AudioModel
17
+ from autobyteus.multimedia.utils.multimedia_config import MultimediaConfig
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def _save_audio_bytes_to_wav(
    pcm_bytes: bytes,
    channels: int = 1,
    rate: int = 24000,
    sample_width: int = 2,
    temp_dir: Optional[str] = None,
) -> str:
    """
    Save raw PCM (s16le by default) audio bytes to a new WAV file and return its path.

    The defaults (mono, 24 kHz, 2-byte samples) match the PCM format assumed
    for Gemini TTS output.

    Args:
        pcm_bytes: Raw PCM frames, written verbatim into the WAV container.
        channels: Number of audio channels.
        rate: Sample rate in Hz.
        sample_width: Bytes per sample (2 => 16-bit).
        temp_dir: Output directory. It defaults to an ``autobyteus_audio``
            folder under ``tempfile.gettempdir()`` (usually ``/tmp`` on Linux),
            so the code also works where ``/tmp`` does not exist (e.g. Windows).
            The directory is created if missing.

    Returns:
        Path of the written file, named ``<uuid4>.wav``.

    Raises:
        Any exception from creating or writing the file. It is logged and
        re-raised, after a partially written file has been removed.
    """
    import tempfile  # local import: only needed to resolve the default directory

    target_dir = temp_dir or os.path.join(tempfile.gettempdir(), "autobyteus_audio")
    os.makedirs(target_dir, exist_ok=True)
    file_path = os.path.join(target_dir, f"{uuid.uuid4()}.wav")

    try:
        with wave.open(file_path, "wb") as wf:
            wf.setnchannels(channels)
            wf.setsampwidth(sample_width)  # 2 bytes => 16-bit
            wf.setframerate(rate)
            wf.writeframes(pcm_bytes)
    except Exception as e:
        logger.error("Failed to save audio to WAV file at %s: %s", file_path, e)
        # Best effort: do not leave a truncated or corrupt WAV file behind.
        # The file may not exist if opening it failed.
        try:
            os.remove(file_path)
        except OSError:
            pass
        raise

    logger.info("Successfully saved generated audio to %s", file_path)
    return file_path
48
+
49
+
50
+ def _extract_inline_audio_bytes(response) -> bytes:
51
+ """
52
+ Extract inline audio bytes from a google.generativeai response.
53
+
54
+ The legacy SDK returns a Response object with candidates -> content -> parts[0].inline_data.data.
55
+ Depending on version, `.data` can be bytes or base64-encoded str.
56
+ """
57
+ try:
58
+ # Access the first candidate's first part's inline_data
59
+ part = response.candidates[0].content.parts[0]
60
+ inline = getattr(part, "inline_data", None)
61
+ if not inline or not hasattr(inline, "data"):
62
+ raise ValueError("No inline audio data found in response.")
63
+ data = inline.data
64
+ if isinstance(data, bytes):
65
+ return data
66
+ if isinstance(data, str):
67
+ return base64.b64decode(data)
68
+ raise TypeError(f"Unexpected inline_data.data type: {type(data)}")
69
+ except Exception as e:
70
+ logger.error("Failed to extract audio from response: %s", e)
71
+ raise
72
+
73
+
74
class GeminiAudioClient(BaseAudioClient):
    """
    An audio client that uses Google's Gemini models for TTS via the *legacy* SDK
    (`google.generativeai`).

    Usage notes:
    - Ensure your model value is a TTS-capable model (e.g. "gemini-2.5-flash-preview-tts"
      or "gemini-2.5-pro-preview-tts"). If the model value is empty,
      "gemini-2.5-flash-preview-tts" is used.
    - Single-speaker is default. For simple usage, provide `voice_name` (e.g. "Kore", "Puck")
      in MultimediaConfig or generation_config. The default voice is "Kore".
    - Multi-speaker preview exists in the API; if you want it, pass:
        generation_config = {
            "mode": "multi-speaker",
            "speakers": [
                {"speaker": "Alice", "voice_name": "Kore"},
                {"speaker": "Bob", "voice_name": "Puck"},
            ]
        }
      and make sure your prompt contains lines for each named speaker.
    - An optional `style_instructions` value is prepended to the prompt as
      "<style_instructions>: <prompt>".
    - Generated audio is written to a temporary WAV file, and its path is
      returned in `SpeechGenerationResponse.audio_urls`.
    """

    def __init__(self, model: "AudioModel", config: "MultimediaConfig"):
        super().__init__(model, config)
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("Please set the GEMINI_API_KEY environment variable.")

        try:
            # The legacy library is configured through a global configure call.
            genai.configure(api_key=api_key)
            self._model = genai.GenerativeModel(self.model.value or "gemini-2.5-flash-preview-tts")
            logger.info("GeminiAudioClient (legacy SDK) configured for model '%s'.", self.model.value)
        except Exception as e:
            logger.error("Failed to configure Gemini client: %s", e)
            # Chain the original error so its traceback is not lost.
            raise RuntimeError(f"Failed to configure Gemini client: {e}") from e

    # NOTE(review): both builders request audio via response_mime_type="audio/pcm".
    # The Gemini TTS guide requests audio with response_modalities=["AUDIO"].
    # Confirm that the API, through the legacy SDK, accepts this mime type.
    @staticmethod
    def _build_single_speaker_generation_config(voice_name: str) -> Dict[str, Any]:
        """
        Build generation_config for single-speaker TTS in the legacy SDK.

        Key bits:
        - response_mime_type => request audio
        - speech_config.voice_config.prebuilt_voice_config.voice_name => set the voice
        """
        return {
            "response_mime_type": "audio/pcm",
            "speech_config": {
                "voice_config": {
                    "prebuilt_voice_config": {
                        "voice_name": voice_name,
                    }
                }
            },
        }

    @staticmethod
    def _build_multi_speaker_generation_config(speakers: List[Dict[str, str]]) -> Dict[str, Any]:
        """
        Build generation_config for multi-speaker TTS (preview).

        Args:
            speakers: ``[{"speaker": "...", "voice_name": "..."}, ...]``.

        Raises:
            ValueError: if any entry lacks a non-empty 'speaker' or 'voice_name'.
        """
        speaker_voice_configs = []
        for s in speakers:
            spk = s.get("speaker")
            vname = s.get("voice_name")
            if not spk or not vname:
                raise ValueError("Each speaker must include 'speaker' and 'voice_name'.")
            speaker_voice_configs.append(
                {
                    "speaker": spk,
                    "voice_config": {
                        "prebuilt_voice_config": {
                            "voice_name": vname,
                        }
                    },
                }
            )
        return {
            "response_mime_type": "audio/pcm",
            "speech_config": {
                "multi_speaker_voice_config": {
                    "speaker_voice_configs": speaker_voice_configs
                }
            },
        }

    async def generate_speech(
        self,
        prompt: str,
        generation_config: Optional[Dict[str, Any]] = None
    ) -> SpeechGenerationResponse:
        """
        Generates spoken audio from text using a Gemini TTS model through the legacy SDK.

        Implementation details:
        - Per-call `generation_config` entries override the client's base config.
        - We call `GenerativeModel.generate_content(...)` with a `generation_config`
          that requests audio and sets the voice settings.
        - The legacy SDK call is synchronous, so it is offloaded to a worker thread.

        Raises:
            ValueError: wrapping (and chained to) any error raised while building
                the request, calling the model, or saving the audio.
        """
        try:
            logger.info("Generating speech with Gemini TTS (legacy SDK) model '%s'...", self.model.value)

            # Merge base config with per-call overrides (a falsy override is a no-op).
            final_cfg = self.config.to_dict().copy()
            final_cfg.update(generation_config or {})

            # Style instructions: prepend if provided
            style_instructions = final_cfg.get("style_instructions")
            final_prompt = f"{style_instructions}: {prompt}" if style_instructions else prompt
            logger.debug("Final prompt for TTS (truncated): '%s...'", final_prompt[:160])

            # Mode & voice
            mode = final_cfg.get("mode", "single-speaker")
            default_voice = final_cfg.get("voice_name", "Kore")

            if mode == "multi-speaker":
                speakers = final_cfg.get("speakers")
                if not speakers or not isinstance(speakers, list):
                    raise ValueError(
                        "For multi-speaker mode, provide generation_config['speakers'] "
                        "as a list of {'speaker': <name>, 'voice_name': <prebuilt voice>}."
                    )
                gen_config = self._build_multi_speaker_generation_config(speakers)
            else:
                gen_config = self._build_single_speaker_generation_config(default_voice)

            # Run the blocking gen call in a thread so this coroutine stays non-blocking
            response = await asyncio.to_thread(
                self._model.generate_content,
                final_prompt,
                generation_config=gen_config,
            )

            audio_pcm = _extract_inline_audio_bytes(response)
            audio_path = _save_audio_bytes_to_wav(audio_pcm)

            return SpeechGenerationResponse(audio_urls=[audio_path])

        except Exception as e:
            logger.error(
                "Error during Google Gemini speech generation (legacy SDK): %s", str(e), exc_info=True
            )
            # Chain the cause so callers and logs keep the original traceback.
            raise ValueError(f"Google Gemini speech generation failed: {str(e)}") from e

    async def cleanup(self):
        """Nothing to release: the legacy SDK keeps no per-client resources here."""
        logger.debug("GeminiAudioClient cleanup called (legacy SDK; nothing to release).")