autobyteus 1.1.5__py3-none-any.whl → 1.1.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- autobyteus/agent/context/agent_config.py +6 -1
- autobyteus/agent/handlers/llm_user_message_ready_event_handler.py +30 -7
- autobyteus/agent/handlers/user_input_message_event_handler.py +22 -25
- autobyteus/agent/message/__init__.py +7 -5
- autobyteus/agent/message/agent_input_user_message.py +6 -16
- autobyteus/agent/message/context_file.py +24 -24
- autobyteus/agent/message/context_file_type.py +29 -8
- autobyteus/agent/message/multimodal_message_builder.py +47 -0
- autobyteus/agent/streaming/stream_event_payloads.py +23 -4
- autobyteus/agent/system_prompt_processor/tool_manifest_injector_processor.py +6 -2
- autobyteus/agent/tool_invocation.py +2 -1
- autobyteus/agent_team/bootstrap_steps/agent_configuration_preparation_step.py +9 -2
- autobyteus/agent_team/context/agent_team_config.py +1 -0
- autobyteus/llm/api/autobyteus_llm.py +33 -33
- autobyteus/llm/api/bedrock_llm.py +13 -5
- autobyteus/llm/api/claude_llm.py +13 -27
- autobyteus/llm/api/gemini_llm.py +108 -42
- autobyteus/llm/api/groq_llm.py +4 -3
- autobyteus/llm/api/mistral_llm.py +97 -51
- autobyteus/llm/api/nvidia_llm.py +6 -5
- autobyteus/llm/api/ollama_llm.py +37 -12
- autobyteus/llm/api/openai_compatible_llm.py +91 -91
- autobyteus/llm/autobyteus_provider.py +1 -1
- autobyteus/llm/base_llm.py +42 -139
- autobyteus/llm/extensions/base_extension.py +6 -6
- autobyteus/llm/extensions/token_usage_tracking_extension.py +3 -2
- autobyteus/llm/llm_factory.py +106 -4
- autobyteus/llm/token_counter/token_counter_factory.py +1 -1
- autobyteus/llm/user_message.py +43 -35
- autobyteus/llm/utils/llm_config.py +34 -18
- autobyteus/llm/utils/media_payload_formatter.py +99 -0
- autobyteus/llm/utils/messages.py +32 -25
- autobyteus/llm/utils/response_types.py +9 -3
- autobyteus/llm/utils/token_usage.py +6 -5
- autobyteus/multimedia/__init__.py +31 -0
- autobyteus/multimedia/audio/__init__.py +11 -0
- autobyteus/multimedia/audio/api/__init__.py +4 -0
- autobyteus/multimedia/audio/api/autobyteus_audio_client.py +59 -0
- autobyteus/multimedia/audio/api/gemini_audio_client.py +219 -0
- autobyteus/multimedia/audio/audio_client_factory.py +120 -0
- autobyteus/multimedia/audio/audio_model.py +96 -0
- autobyteus/multimedia/audio/autobyteus_audio_provider.py +108 -0
- autobyteus/multimedia/audio/base_audio_client.py +40 -0
- autobyteus/multimedia/image/__init__.py +11 -0
- autobyteus/multimedia/image/api/__init__.py +9 -0
- autobyteus/multimedia/image/api/autobyteus_image_client.py +97 -0
- autobyteus/multimedia/image/api/gemini_image_client.py +188 -0
- autobyteus/multimedia/image/api/openai_image_client.py +142 -0
- autobyteus/multimedia/image/autobyteus_image_provider.py +109 -0
- autobyteus/multimedia/image/base_image_client.py +67 -0
- autobyteus/multimedia/image/image_client_factory.py +118 -0
- autobyteus/multimedia/image/image_model.py +96 -0
- autobyteus/multimedia/providers.py +5 -0
- autobyteus/multimedia/runtimes.py +8 -0
- autobyteus/multimedia/utils/__init__.py +10 -0
- autobyteus/multimedia/utils/api_utils.py +19 -0
- autobyteus/multimedia/utils/multimedia_config.py +29 -0
- autobyteus/multimedia/utils/response_types.py +13 -0
- autobyteus/tools/__init__.py +3 -0
- autobyteus/tools/multimedia/__init__.py +8 -0
- autobyteus/tools/multimedia/audio_tools.py +116 -0
- autobyteus/tools/multimedia/image_tools.py +186 -0
- autobyteus/tools/tool_category.py +1 -0
- autobyteus/tools/usage/parsers/provider_aware_tool_usage_parser.py +5 -2
- autobyteus/tools/usage/providers/tool_manifest_provider.py +5 -3
- autobyteus/tools/usage/registries/tool_formatting_registry.py +9 -2
- autobyteus/tools/usage/registries/tool_usage_parser_registry.py +9 -2
- {autobyteus-1.1.5.dist-info → autobyteus-1.1.6.dist-info}/METADATA +9 -9
- {autobyteus-1.1.5.dist-info → autobyteus-1.1.6.dist-info}/RECORD +73 -45
- examples/run_browser_agent.py +1 -1
- autobyteus/llm/utils/image_payload_formatter.py +0 -89
- {autobyteus-1.1.5.dist-info → autobyteus-1.1.6.dist-info}/WHEEL +0 -0
- {autobyteus-1.1.5.dist-info → autobyteus-1.1.6.dist-info}/licenses/LICENSE +0 -0
- {autobyteus-1.1.5.dist-info → autobyteus-1.1.6.dist-info}/top_level.txt +0 -0
autobyteus/llm/utils/llm_config.py
CHANGED

@@ -55,6 +55,7 @@ class LLMConfig:
     frequency_penalty: Optional[float] = None
     presence_penalty: Optional[float] = None
     stop_sequences: Optional[List] = None
+    uses_max_completion_tokens: bool = False
     extra_params: Dict[str, Any] = field(default_factory=dict)
     pricing_config: TokenPricingConfig = field(default_factory=TokenPricingConfig)
 
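The new `uses_max_completion_tokens` flag records whether a model expects OpenAI's newer `max_completion_tokens` request parameter in place of `max_tokens`. A minimal sketch of how an adapter might branch on it; the `build_request_params` helper below is hypothetical, not part of the package:

from autobyteus.llm.utils.llm_config import LLMConfig

def build_request_params(config: LLMConfig) -> dict:
    """Hypothetical helper mapping an LLMConfig onto provider request kwargs."""
    params = {"temperature": config.temperature}
    if config.max_tokens is not None:
        # Some newer OpenAI-style models reject `max_tokens` and require
        # `max_completion_tokens`; the flag records which spelling to use.
        key = "max_completion_tokens" if config.uses_max_completion_tokens else "max_tokens"
        params[key] = config.max_tokens
    return params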
@@ -102,17 +103,28 @@ class LLMConfig:
         data_copy = data.copy()
         pricing_config_data = data_copy.pop('pricing_config', {})
 
+        # Create a new dictionary for known fields to avoid passing them in twice
+        known_fields = {
+            'rate_limit', 'token_limit', 'system_message', 'temperature',
+            'max_tokens', 'top_p', 'frequency_penalty', 'presence_penalty',
+            'stop_sequences', 'uses_max_completion_tokens', 'extra_params',
+            'pricing_config'
+        }
+
+        init_kwargs = {k: v for k, v in data_copy.items() if k in known_fields}
+
         config = cls(
-            rate_limit=
-            token_limit=
-            system_message=
-            temperature=
-            max_tokens=
-            top_p=
-            frequency_penalty=
-            presence_penalty=
-            stop_sequences=
-
+            rate_limit=init_kwargs.get('rate_limit'),
+            token_limit=init_kwargs.get('token_limit'),
+            system_message=init_kwargs.get('system_message', "You are a helpful assistant."),
+            temperature=init_kwargs.get('temperature', 0.7),
+            max_tokens=init_kwargs.get('max_tokens'),
+            top_p=init_kwargs.get('top_p'),
+            frequency_penalty=init_kwargs.get('frequency_penalty'),
+            presence_penalty=init_kwargs.get('presence_penalty'),
+            stop_sequences=init_kwargs.get('stop_sequences'),
+            uses_max_completion_tokens=init_kwargs.get('uses_max_completion_tokens', False),
+            extra_params=init_kwargs.get('extra_params', {}),
             pricing_config=pricing_config_data
         )
         return config
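With the whitelist in place, `from_dict` silently drops unrecognized keys instead of forwarding them to the constructor, so a config dict that has picked up extra keys no longer breaks construction. For example:

from autobyteus.llm.utils.llm_config import LLMConfig

data = {
    "temperature": 0.2,
    "uses_max_completion_tokens": True,
    "some_legacy_key": "ignored",  # unknown keys are filtered out, not passed to cls()
}
config = LLMConfig.from_dict(data)
assert config.temperature == 0.2
assert config.uses_max_completion_tokens is True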
@@ -162,26 +174,30 @@ class LLMConfig:
         for f_info in fields(override_config):
             override_value = getattr(override_config, f_info.name)
 
+            # Special handling for booleans where we want to merge if it's not the default
+            # For `uses_max_completion_tokens`, the default is False, so `if override_value:` is fine
+            is_boolean_field = f_info.type == bool
+
+            # Standard check for None, but also merge if it's a non-default boolean
             if override_value is not None:
-
-
+                # For uses_max_completion_tokens, `False` is a valid override value, but `None` is not
+                if is_boolean_field and override_value is False and getattr(self, f_info.name) is True:
+                    setattr(self, f_info.name, override_value)
+                elif f_info.name == 'pricing_config':
                     if not isinstance(self.pricing_config, TokenPricingConfig):
-                        self.pricing_config = TokenPricingConfig()
+                        self.pricing_config = TokenPricingConfig()
 
-                # override_value here is override_config.pricing_config, which is TokenPricingConfig
                     if isinstance(override_value, TokenPricingConfig):
                         self.pricing_config.merge_with(override_value)
-                elif isinstance(override_value, dict):
+                    elif isinstance(override_value, dict):
                         self.pricing_config.merge_with(TokenPricingConfig.from_dict(override_value))
                     else:
                         logger.warning(f"Skipping merge for pricing_config due to unexpected override type: {type(override_value)}")
                 elif f_info.name == 'extra_params':
-                # For extra_params (dict), merge dictionaries
                     if isinstance(override_value, dict) and isinstance(self.extra_params, dict):
                         self.extra_params.update(override_value)
                     else:
-                    setattr(self, f_info.name, override_value)
+                        setattr(self, f_info.name, override_value)
                 else:
                     setattr(self, f_info.name, override_value)
         logger.debug(f"LLMConfig merged. Current state after merge: rate_limit={self.rate_limit}, temp={self.temperature}, system_message='{self.system_message}'")
-
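The merge logic now lets an explicit `False` override an existing `True` for boolean fields, instead of treating every falsy value as "no override". Assuming the enclosing method is `LLMConfig.merge_with`, as its log message suggests, the resulting behaviour is:

base = LLMConfig.from_dict({"uses_max_completion_tokens": True, "temperature": 0.7})
override = LLMConfig.from_dict({"uses_max_completion_tokens": False, "temperature": 0.2})

base.merge_with(override)
# A non-default boolean `False` now wins over an existing `True`,
# while scalar overrides such as temperature apply as before.
assert base.uses_max_completion_tokens is False
assert base.temperature == 0.2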
autobyteus/llm/utils/media_payload_formatter.py
ADDED

@@ -0,0 +1,99 @@
+import base64
+import mimetypes
+from typing import Dict, Union
+from pathlib import Path
+import httpx
+import logging
+
+logger = logging.getLogger(__name__)
+
+# FIX: Instantiate the client with verify=False to allow for self-signed certificates
+# in local development environments, which is a common use case.
+_http_client = httpx.AsyncClient(verify=False)
+
+# Add a prominent security warning to inform developers about the disabled SSL verification.
+logger.warning(
+    "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
+    "SECURITY WARNING: SSL certificate verification is DISABLED for the image "
+    "downloader (httpx client in media_payload_formatter.py).\n"
+    "This is intended for development and testing with local servers using "
+    "self-signed certificates. In a production environment, this could expose "
+    "the system to Man-in-the-Middle (MitM) attacks when downloading images from "
+    "the public internet.\n"
+    "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
+)
+
+
+def get_mime_type(file_path: str) -> str:
+    """Determine MIME type of file."""
+    mime_type, _ = mimetypes.guess_type(file_path)
+    if not mime_type or not mime_type.startswith('image/'):
+        return 'image/jpeg'  # default fallback
+    return mime_type
+
+
+def is_base64(s: str) -> bool:
+    """Check if a string is a valid base64 encoded string."""
+    try:
+        # Check if the string has valid base64 characters and padding
+        if not isinstance(s, str) or len(s) % 4 != 0:
+            return False
+        base64.b64decode(s, validate=True)
+        return True
+    except (ValueError, TypeError):
+        return False
+
+
+def is_valid_image_path(path: str) -> bool:
+    """Check if path exists and has a valid image extension."""
+    valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
+    try:
+        file_path = Path(path)
+        return file_path.is_file() and file_path.suffix.lower() in valid_extensions
+    except (TypeError, ValueError):
+        return False
+
+
+def create_data_uri(mime_type: str, base64_data: str) -> Dict:
+    """Create properly structured data URI object for API."""
+    return {
+        "type": "image_url",
+        "image_url": {
+            "url": f"data:{mime_type};base64,{base64_data}"
+        }
+    }
+
+def file_to_base64(path: str) -> str:
+    """Reads an image file from a local path and returns it as a base64 encoded string."""
+    try:
+        with open(path, "rb") as img_file:
+            return base64.b64encode(img_file.read()).decode("utf-8")
+    except Exception as e:
+        logger.error(f"Failed to read and encode image file at {path}: {e}")
+        raise
+
+async def url_to_base64(url: str) -> str:
+    """Downloads an image from a URL and returns it as a base64 encoded string."""
+    try:
+        response = await _http_client.get(url)
+        response.raise_for_status()
+        return base64.b64encode(response.content).decode("utf-8")
+    except httpx.HTTPError as e:
+        logger.error(f"Failed to download image from URL {url}: {e}")
+        raise
+
+async def image_source_to_base64(image_source: str) -> str:
+    """
+    Orchestrator function that converts an image source (file path, URL, or existing base64)
+    into a base64 encoded string by delegating to specialized functions.
+    """
+    if is_valid_image_path(image_source):
+        return file_to_base64(image_source)
+
+    if image_source.startswith(("http://", "https://")):
+        return await url_to_base64(image_source)
+
+    if is_base64(image_source):
+        return image_source
+
+    raise ValueError(f"Invalid image source: not a valid file path, URL, or base64 string.")
autobyteus/llm/utils/messages.py
CHANGED

@@ -7,34 +7,41 @@ class MessageRole(Enum):
     ASSISTANT = "assistant"
 
 class Message:
-    def __init__(self,
+    def __init__(self,
+                 role: MessageRole,
+                 content: Optional[str] = None,
+                 reasoning_content: Optional[str] = None,
+                 image_urls: Optional[List[str]] = None,
+                 audio_urls: Optional[List[str]] = None,
+                 video_urls: Optional[List[str]] = None):
         """
-        Initializes a Message.
-
+        Initializes a rich Message object for conversation history.
+
         Args:
-            role
-            content
-            reasoning_content
+            role: The role of the message originator.
+            content: The textual content of the message.
+            reasoning_content: Optional reasoning/thought process from an assistant.
+            image_urls: Optional list of image URIs.
+            audio_urls: Optional list of audio URIs.
+            video_urls: Optional list of video URIs.
         """
         self.role = role
         self.content = content
-        self.reasoning_content = reasoning_content
-
-
-
-        if self.reasoning_content:
-            result["reasoning_content"] = self.reasoning_content
-        return result
+        self.reasoning_content = reasoning_content
+        self.image_urls = image_urls or []
+        self.audio_urls = audio_urls or []
+        self.video_urls = video_urls or []
 
-    def
-
-
-
-
-
-
-
-
-
-
+    def to_dict(self) -> Dict[str, Union[str, List[str], None]]:
+        """
+        Returns a simple dictionary representation of the Message object.
+        This is for internal use and does not format for any specific API.
+        """
+        return {
+            "role": self.role.value,
+            "content": self.content,
+            "reasoning_content": self.reasoning_content,
+            "image_urls": self.image_urls,
+            "audio_urls": self.audio_urls,
+            "video_urls": self.video_urls,
+        }
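`Message` thus becomes a multimodal container whose `to_dict` output is provider-agnostic. For example (assuming a `USER` member on `MessageRole`, which is not shown in the hunk):

from autobyteus.llm.utils.messages import Message, MessageRole

msg = Message(
    role=MessageRole.USER,
    content="What is shown in this image?",
    image_urls=["file:///tmp/example.png"],  # hypothetical URI
)
print(msg.to_dict())
# {'role': 'user', 'content': 'What is shown in this image?',
#  'reasoning_content': None, 'image_urls': ['file:///tmp/example.png'],
#  'audio_urls': [], 'video_urls': []}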
autobyteus/llm/utils/response_types.py
CHANGED

@@ -1,5 +1,5 @@
-from dataclasses import dataclass
-from typing import Optional
+from dataclasses import dataclass, field
+from typing import Optional, List
 from autobyteus.llm.utils.token_usage import TokenUsage
 
 @dataclass
@@ -7,6 +7,9 @@ class CompleteResponse:
     content: str
     reasoning: Optional[str] = None
     usage: Optional[TokenUsage] = None
+    image_urls: List[str] = field(default_factory=list)
+    audio_urls: List[str] = field(default_factory=list)
+    video_urls: List[str] = field(default_factory=list)
 
     @classmethod
     def from_content(cls, content: str) -> 'CompleteResponse':
@@ -17,4 +20,7 @@ class ChunkResponse:
     content: str  # The actual content/text of the chunk
     reasoning: Optional[str] = None
     is_complete: bool = False  # Indicates if this is the final chunk
-    usage: Optional[TokenUsage] = None  # Token usage stats, typically available in final chunk
+    usage: Optional[TokenUsage] = None  # Token usage stats, typically available in final chunk
+    image_urls: List[str] = field(default_factory=list)
+    audio_urls: List[str] = field(default_factory=list)
+    video_urls: List[str] = field(default_factory=list)
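Both response dataclasses can now carry generated media alongside text, defaulting to empty lists. For instance:

from autobyteus.llm.utils.response_types import CompleteResponse

resp = CompleteResponse(
    content="Here is the picture you asked for.",
    image_urls=["/tmp/autobyteus_image/abc123.png"],  # hypothetical path
)
assert resp.audio_urls == [] and resp.video_urls == []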
autobyteus/llm/utils/token_usage.py
CHANGED

@@ -1,8 +1,8 @@
 # file: autobyteus/autobyteus/llm/utils/token_usage.py
 from typing import Optional
-from pydantic import BaseModel # MODIFIED: Import
+from pydantic import BaseModel, ConfigDict # MODIFIED: Import ConfigDict
 
-# MODIFIED: Change from dataclass to Pydantic BaseModel
+# MODIFIED: Change from dataclass to Pydantic BaseModel and use model_config
 class TokenUsage(BaseModel):
     prompt_tokens: int
     completion_tokens: int
@@ -11,6 +11,7 @@ class TokenUsage(BaseModel):
     completion_cost: Optional[float] = None
     total_cost: Optional[float] = None
 
-
-
-
+    # FIX: Use model_config with ConfigDict for Pydantic v2 compatibility
+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
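`model_config = ConfigDict(...)` is Pydantic v2's replacement for the old inner `Config` class, and `populate_by_name=True` keeps field-name construction working even when aliases are configured. A small sketch (the `total_tokens` field is assumed; the hunk elides the middle of the class):

from autobyteus.llm.utils.token_usage import TokenUsage

usage = TokenUsage(prompt_tokens=12, completion_tokens=30, total_tokens=42)
print(usage.model_dump())  # Pydantic v2 API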
autobyteus/multimedia/__init__.py
ADDED

@@ -0,0 +1,31 @@
+from .providers import MultimediaProvider
+from .runtimes import MultimediaRuntime
+from .utils import *
+from .image import *
+from .audio import *
+
+
+__all__ = [
+    # Factories
+    "image_client_factory",
+    "ImageClientFactory",
+    "audio_client_factory",
+    "AudioClientFactory",
+
+    # Models
+    "ImageModel",
+    "AudioModel",
+
+    # Base Clients
+    "BaseImageClient",
+    "BaseAudioClient",
+
+    # Enums
+    "MultimediaProvider",
+    "MultimediaRuntime",
+
+    # Response Types and Config
+    "ImageGenerationResponse",
+    "SpeechGenerationResponse",
+    "MultimediaConfig",
+]
autobyteus/multimedia/audio/__init__.py
ADDED

@@ -0,0 +1,11 @@
+from .audio_client_factory import audio_client_factory, AudioClientFactory
+from .audio_model import AudioModel
+from .base_audio_client import BaseAudioClient
+from .api import *
+
+__all__ = [
+    "audio_client_factory",
+    "AudioClientFactory",
+    "AudioModel",
+    "BaseAudioClient",
+]
autobyteus/multimedia/audio/api/autobyteus_audio_client.py
ADDED

@@ -0,0 +1,59 @@
+import logging
+from typing import Optional, List, Dict, Any, TYPE_CHECKING
+from autobyteus_llm_client import AutobyteusClient
+from autobyteus.multimedia.audio.base_audio_client import BaseAudioClient
+from autobyteus.multimedia.utils.response_types import SpeechGenerationResponse
+
+if TYPE_CHECKING:
+    from autobyteus.multimedia.audio.audio_model import AudioModel
+    from autobyteus.multimedia.utils.multimedia_config import MultimediaConfig
+
+logger = logging.getLogger(__name__)
+
+class AutobyteusAudioClient(BaseAudioClient):
+    """
+    An audio client that connects to an Autobyteus server instance for audio tasks.
+    """
+
+    def __init__(self, model: "AudioModel", config: "MultimediaConfig"):
+        super().__init__(model, config)
+        if not model.host_url:
+            raise ValueError("AutobyteusAudioClient requires a host_url in its AudioModel.")
+
+        self.autobyteus_client = AutobyteusClient(server_url=model.host_url)
+        logger.info(f"AutobyteusAudioClient initialized for model '{model.name}' on host '{model.host_url}'.")
+
+    async def generate_speech(
+        self,
+        prompt: str,
+        generation_config: Optional[Dict[str, Any]] = None
+    ) -> SpeechGenerationResponse:
+        """
+        Generates speech by calling the generate_speech endpoint on the remote Autobyteus server.
+        """
+        try:
+            logger.info(f"Sending speech generation request for model '{self.model.name}' to {self.model.host_url}")
+
+            model_name_for_server = self.model.name
+
+            response_data = await self.autobyteus_client.generate_speech(
+                model_name=model_name_for_server,
+                prompt=prompt,
+                generation_config=generation_config
+            )
+
+            audio_urls = response_data.get("audio_urls", [])
+            if not audio_urls:
+                raise ValueError("Remote Autobyteus server did not return any audio URLs.")
+
+            return SpeechGenerationResponse(audio_urls=audio_urls)
+
+        except Exception as e:
+            logger.error(f"Error calling Autobyteus server for speech generation: {e}", exc_info=True)
+            raise
+
+    async def cleanup(self):
+        """Closes the underlying AutobyteusClient."""
+        if self.autobyteus_client:
+            await self.autobyteus_client.close()
+            logger.debug("AutobyteusAudioClient cleaned up.")
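A sketch of how a caller might obtain and drive such a client; the factory method name below is a guess inferred from the exported `audio_client_factory` singleton, so treat it as illustrative:

import asyncio
from autobyteus.multimedia.audio import audio_client_factory

async def main():
    # `create_audio_client` is a hypothetical factory method; model names come
    # from the AudioModel registry on the target deployment.
    client = audio_client_factory.create_audio_client("gemini-2.5-flash-preview-tts")
    try:
        response = await client.generate_speech("Hello from autobyteus!")
        print(response.audio_urls)
    finally:
        await client.cleanup()

asyncio.run(main())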
autobyteus/multimedia/audio/api/gemini_audio_client.py
ADDED

@@ -0,0 +1,219 @@
+import asyncio
+import base64
+import logging
+import os
+import uuid
+import wave
+from typing import Optional, Dict, Any, TYPE_CHECKING, List
+
+# Old/legacy Gemini SDK (as requested)
+import google.generativeai as genai
+
+from autobyteus.multimedia.audio.base_audio_client import BaseAudioClient
+from autobyteus.multimedia.utils.response_types import SpeechGenerationResponse
+
+if TYPE_CHECKING:
+    from autobyteus.multimedia.audio.audio_model import AudioModel
+    from autobyteus.multimedia.utils.multimedia_config import MultimediaConfig
+
+logger = logging.getLogger(__name__)
+
+
+def _save_audio_bytes_to_wav(
+    pcm_bytes: bytes,
+    channels: int = 1,
+    rate: int = 24000,
+    sample_width: int = 2,
+) -> str:
+    """
+    Save raw PCM (s16le) audio bytes to a temporary WAV file and return the file path.
+
+    Gemini TTS models output mono, 24 kHz, 16-bit PCM by default.
+    """
+    temp_dir = "/tmp/autobyteus_audio"
+    os.makedirs(temp_dir, exist_ok=True)
+    file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.wav")
+
+    try:
+        with wave.open(file_path, "wb") as wf:
+            wf.setnchannels(channels)
+            wf.setsampwidth(sample_width)  # 2 bytes => 16-bit
+            wf.setframerate(rate)
+            wf.writeframes(pcm_bytes)
+        logger.info("Successfully saved generated audio to %s", file_path)
+        return file_path
+    except Exception as e:
+        logger.error("Failed to save audio to WAV file at %s: %s", file_path, e)
+        raise
+
+
+def _extract_inline_audio_bytes(response) -> bytes:
+    """
+    Extract inline audio bytes from a google.generativeai response.
+
+    The legacy SDK returns a Response object with candidates -> content -> parts[0].inline_data.data.
+    Depending on version, `.data` can be bytes or base64-encoded str.
+    """
+    try:
+        # Access the first candidate's first part's inline_data
+        part = response.candidates[0].content.parts[0]
+        inline = getattr(part, "inline_data", None)
+        if not inline or not hasattr(inline, "data"):
+            raise ValueError("No inline audio data found in response.")
+        data = inline.data
+        if isinstance(data, bytes):
+            return data
+        if isinstance(data, str):
+            return base64.b64decode(data)
+        raise TypeError(f"Unexpected inline_data.data type: {type(data)}")
+    except Exception as e:
+        logger.error("Failed to extract audio from response: %s", e)
+        raise
+
+
+class GeminiAudioClient(BaseAudioClient):
+    """
+    An audio client that uses Google's Gemini models for TTS via the *legacy* SDK
+    (`google.generativeai`).
+
+    Usage notes:
+    - Ensure your model value is a TTS-capable model (e.g. "gemini-2.5-flash-preview-tts"
+      or "gemini-2.5-pro-preview-tts").
+    - Single-speaker is default. For simple usage, provide `voice_name` (e.g. "Kore", "Puck")
+      in MultimediaConfig or generation_config.
+    - Multi-speaker preview exists in the API; if you want it, pass:
+        generation_config = {
+            "mode": "multi-speaker",
+            "speakers": [
+                {"speaker": "Alice", "voice_name": "Kore"},
+                {"speaker": "Bob", "voice_name": "Puck"},
+            ]
+        }
+      and make sure your prompt contains lines for each named speaker.
+    """
+
+    def __init__(self, model: "AudioModel", config: "MultimediaConfig"):
+        super().__init__(model, config)
+        api_key = os.getenv("GEMINI_API_KEY")
+        if not api_key:
+            raise ValueError("Please set the GEMINI_API_KEY environment variable.")
+
+        try:
+            # Legacy library uses a global configure call
+            genai.configure(api_key=api_key)
+            # Create a GenerativeModel handle
+            self._model = genai.GenerativeModel(self.model.value or "gemini-2.5-flash-preview-tts")
+            logger.info("GeminiAudioClient (legacy SDK) configured for model '%s'.", self.model.value)
+        except Exception as e:
+            logger.error("Failed to configure Gemini client: %s", e)
+            raise RuntimeError(f"Failed to configure Gemini client: {e}")
+
+    @staticmethod
+    def _build_single_speaker_generation_config(voice_name: str) -> Dict[str, Any]:
+        """
+        Build generation_config for single-speaker TTS in the legacy SDK.
+        Key bits:
+        - response_mime_type => request audio
+        - speech_config.voice_config.prebuilt_voice_config.voice_name => set the voice
+        """
+        return {
+            "response_mime_type": "audio/pcm",
+            "speech_config": {
+                "voice_config": {
+                    "prebuilt_voice_config": {
+                        "voice_name": voice_name,
+                    }
+                }
+            },
+        }
+
+    @staticmethod
+    def _build_multi_speaker_generation_config(speakers: List[Dict[str, str]]) -> Dict[str, Any]:
+        """
+        Build generation_config for multi-speaker TTS (preview).
+        `speakers` = [{"speaker": "...", "voice_name": "..."}, ...]
+        """
+        speaker_voice_configs = []
+        for s in speakers:
+            spk = s.get("speaker")
+            vname = s.get("voice_name")
+            if not spk or not vname:
+                raise ValueError("Each speaker must include 'speaker' and 'voice_name'.")
+            speaker_voice_configs.append(
+                {
+                    "speaker": spk,
+                    "voice_config": {
+                        "prebuilt_voice_config": {
+                            "voice_name": vname,
+                        }
+                    },
+                }
+            )
+        return {
+            "response_mime_type": "audio/pcm",
+            "speech_config": {
+                "multi_speaker_voice_config": {
+                    "speaker_voice_configs": speaker_voice_configs
+                }
+            },
+        }
+
+    async def generate_speech(
+        self,
+        prompt: str,
+        generation_config: Optional[Dict[str, Any]] = None
+    ) -> SpeechGenerationResponse:
+        """
+        Generates spoken audio from text using a Gemini TTS model through the legacy SDK.
+
+        Implementation details:
+        - We call `GenerativeModel.generate_content(...)` with a `generation_config`
+          that asks for AUDIO and sets the voice settings.
+        - The legacy SDK call is synchronous; we offload to a worker thread.
+        """
+        try:
+            logger.info("Generating speech with Gemini TTS (legacy SDK) model '%s'...", self.model.value)
+
+            # Merge base config with per-call overrides
+            final_cfg = self.config.to_dict().copy()
+            if generation_config:
+                final_cfg.update(generation_config or {})
+
+            # Style instructions: prepend if provided
+            style_instructions = final_cfg.get("style_instructions")
+            final_prompt = f"{style_instructions}: {prompt}" if style_instructions else prompt
+            logger.debug("Final prompt for TTS (truncated): '%s...'", final_prompt[:160])
+
+            # Mode & voice
+            mode = final_cfg.get("mode", "single-speaker")
+            default_voice = final_cfg.get("voice_name", "Kore")
+
+            if mode == "multi-speaker":
+                speakers = final_cfg.get("speakers")
+                if not speakers or not isinstance(speakers, list):
+                    raise ValueError(
+                        "For multi-speaker mode, provide generation_config['speakers'] "
+                        "as a list of {'speaker': <name>, 'voice_name': <prebuilt voice>}."
+                    )
+                gen_config = self._build_multi_speaker_generation_config(speakers)
+            else:
+                gen_config = self._build_single_speaker_generation_config(default_voice)
+
+            # Run the blocking gen call in a thread so this coroutine stays non-blocking
+            response = await asyncio.to_thread(
+                self._model.generate_content,
+                final_prompt,
+                generation_config=gen_config,
+            )
+
+            audio_pcm = _extract_inline_audio_bytes(response)
+            audio_path = _save_audio_bytes_to_wav(audio_pcm)
+
+            return SpeechGenerationResponse(audio_urls=[audio_path])
+
+        except Exception as e:
+            logger.error("Error during Google Gemini speech generation (legacy SDK): %s", str(e))
+            raise ValueError(f"Google Gemini speech generation failed: {str(e)}")
+
+    async def cleanup(self):
+        logger.debug("GeminiAudioClient cleanup called (legacy SDK; nothing to release).")