voicerun_completions-0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- voicerun_completions/__init__.py +0 -0
- voicerun_completions/client.py +165 -0
- voicerun_completions/providers/anthropic/anthropic_client.py +192 -0
- voicerun_completions/providers/anthropic/streaming.py +197 -0
- voicerun_completions/providers/anthropic/utils.py +193 -0
- voicerun_completions/providers/base.py +145 -0
- voicerun_completions/providers/google/google_client.py +165 -0
- voicerun_completions/providers/google/streaming.py +177 -0
- voicerun_completions/providers/google/utils.py +142 -0
- voicerun_completions/providers/openai/openai_client.py +159 -0
- voicerun_completions/providers/openai/streaming.py +182 -0
- voicerun_completions/providers/openai/utils.py +135 -0
- voicerun_completions-0.1.2.dist-info/METADATA +46 -0
- voicerun_completions-0.1.2.dist-info/RECORD +16 -0
- voicerun_completions-0.1.2.dist-info/WHEEL +5 -0
- voicerun_completions-0.1.2.dist-info/top_level.txt +1 -0
voicerun_completions/providers/anthropic/utils.py
@@ -0,0 +1,193 @@
+import json
+from typing import Optional, List, Union
+from anthropic.types import (
+    MessageParam as AnthropicMessage,
+    ToolParam as AnthropicToolDefinition,
+    ToolChoiceParam as AnthropicToolChoice,
+    ContentBlock as AnthropicContentBlock,
+    TextBlockParam as AnthropicTextBlock,
+    ToolUseBlockParam as AnthropicToolCall,
+    ToolResultBlockParam as AnthropicToolResult,
+    ToolChoiceAutoParam as AnthropicToolChoiceAuto,
+    ToolChoiceAnyParam as AnthropicToolChoiceAny,
+    ToolChoiceNoneParam as AnthropicToolChoiceNone,
+    ToolChoiceToolParam as AnthropicToolChoiceToolName,
+    CacheControlEphemeralParam as AnthropicCacheControl,
+)
+from primfunctions.completions.messages import (
+    ConversationHistory,
+    UserMessage,
+    AssistantMessage,
+    SystemMessage,
+    ToolResultMessage,
+    ToolCall,
+)
+from primfunctions.completions.request import ToolChoice, ToolDefinition
+
+
+def denormalize_tool_calls(normalized_tool_calls: Optional[List[ToolCall]]) -> Optional[list[AnthropicToolCall]]:
+    """Convert normalized ToolCall objects to Anthropic tool_use blocks."""
+
+    tool_calls: Optional[list[AnthropicToolCall]] = None
+    if normalized_tool_calls:
+        tool_calls = []
+        for tc in normalized_tool_calls:
+            tool_call: AnthropicToolCall = {
+                "type": "tool_use",
+                "id": tc.id,
+                "name": tc.function.name,
+                "input": tc.function.arguments,
+            }
+            tool_calls.append(tool_call)
+
+    return tool_calls
+
+
+def denormalize_conversation_history(normalized_messages: ConversationHistory) -> tuple[list[AnthropicMessage], Optional[list[AnthropicTextBlock]]]:
+    """Convert normalized Message objects to Anthropic MessageParam format.
+
+    Returns:
+        Tuple of (messages, system_prompt) since Anthropic handles system prompts separately.
+    """
+
+    messages: list[AnthropicMessage] = []
+    system_messages: list[AnthropicTextBlock] = []
+
+    i = 0
+    while i < len(normalized_messages):
+        msg = normalized_messages[i]
+
+        match msg:
+            case UserMessage():
+                # User messages: simple content handling
+                user_message: AnthropicMessage = {
+                    "role": "user",
+                    "content": [
+                        AnthropicTextBlock(
+                            type="text",
+                            text=msg.content or "",
+                            cache_control=AnthropicCacheControl(
+                                type="ephemeral",
+                                ttl=msg.cache_breakpoint.ttl,
+                            ) if msg.cache_breakpoint else None
+                        )
+                    ],
+                }
+                messages.append(user_message)
+                i += 1
+            case AssistantMessage():
+                # Build content blocks
+                content_blocks: list[AnthropicContentBlock] = []
+
+                # Add text content if present
+                if msg.content:
+                    content_blocks.append(AnthropicTextBlock(
+                        type="text",
+                        text=msg.content,
+                    ))
+
+                # Add tool calls
+                if msg.tool_calls:
+                    tool_calls = denormalize_tool_calls(msg.tool_calls)
+                    if tool_calls:
+                        content_blocks.extend(tool_calls)
+
+                # Capture cache control of assistant message
+                cache_control = AnthropicCacheControl(
+                    type="ephemeral",
+                    ttl=msg.cache_breakpoint.ttl,
+                ) if msg.cache_breakpoint else None
+
+                # Add cache control to final block if present and blocks exist
+                if cache_control and content_blocks:
+                    final_block: Union[AnthropicTextBlock, AnthropicToolCall] = content_blocks[-1]
+                    final_block["cache_control"] = cache_control
+
+                assistant_message: AnthropicMessage = {
+                    "role": "assistant",
+                    "content": content_blocks
+                }
+
+                messages.append(assistant_message)
+                i += 1
+            case SystemMessage():
+                # Collect all system and developer messages
+                system_messages.append(
+                    AnthropicTextBlock(
+                        type="text",
+                        text=msg.content or "",
+                        cache_control=AnthropicCacheControl(
+                            type="ephemeral",
+                            ttl=msg.cache_breakpoint.ttl,
+                        ) if msg.cache_breakpoint else None
+                    )
+                )
+                i += 1
+            case ToolResultMessage():
+                # Look ahead to group consecutive tool messages into a single user message
+                tool_results: list[AnthropicToolResult] = []
+                while i < len(normalized_messages) and normalized_messages[i].role == "tool":
+                    tool_msg: ToolResultMessage = normalized_messages[i]
+                    tool_results.append(AnthropicToolResult(
+                        type="tool_result",
+                        tool_use_id=tool_msg.tool_call_id,
+                        content=json.dumps(tool_msg.content),
+                        cache_control=AnthropicCacheControl(
+                            type="ephemeral",
+                            ttl=tool_msg.cache_breakpoint.ttl,
+                        ) if tool_msg.cache_breakpoint else None
+                    ))
+                    i += 1
+
+                # Create user message with all tool results
+                user_message = {
+                    "role": "user",
+                    "content": tool_results
+                }
+                messages.append(user_message)
+            case _:
+                # Skip unsupported roles
+                i += 1
+
+    # Capture system prompt as list of TextBlocks
+    system_prompt: Optional[list[AnthropicTextBlock]] = system_messages or None
+
+    return messages, system_prompt
+
+
+def denormalize_tools(normalized_tools: Optional[list[ToolDefinition]]) -> Optional[list[AnthropicToolDefinition]]:
+    """Convert normalized Tool objects to Anthropic ToolParam format."""
+
+    tools: Optional[list[AnthropicToolDefinition]] = None
+    if normalized_tools:
+        tools = []
+        for tool in normalized_tools:
+            tool_def: AnthropicToolDefinition = {
+                "name": tool.function.name,
+                "description": tool.function.description,
+                "input_schema": tool.function.parameters,
+            }
+            tools.append(tool_def)
+
+    return tools
+
+
+def denormalize_tool_choice(normalized_tool_choice: Optional[ToolChoice]) -> Optional[AnthropicToolChoice]:
+    """Convert normalized ToolChoice to Anthropic ToolChoiceParam format."""
+
+    if not normalized_tool_choice:
+        return None
+
+    # Map our normalized tool choice to Anthropic's format using proper SDK types
+    if normalized_tool_choice == "auto":
+        return AnthropicToolChoiceAuto(type="auto")
+    elif normalized_tool_choice == "none":
+        return AnthropicToolChoiceNone(type="none")
+    elif normalized_tool_choice == "required":
+        return AnthropicToolChoiceAny(type="any")  # Anthropic uses "any" for required
+    else:
+        # Specific tool name - use ToolChoiceToolParam
+        return AnthropicToolChoiceToolName(
+            type="tool",
+            name=normalized_tool_choice
+        )
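The ToolResultMessage arm is the one structural transform worth calling out: Anthropic's Messages API has no dedicated "tool" role, so the loop looks ahead and folds each run of consecutive tool results into a single user-role message of tool_result blocks. A minimal standalone sketch of that grouping, with plain dicts standing in for the primfunctions message types (the dict shapes here are illustrative, not the real classes):

def group_tool_results(messages: list[dict]) -> list[dict]:
    """Fold each run of role == "tool" messages into one user-role message."""
    out: list[dict] = []
    i = 0
    while i < len(messages):
        if messages[i]["role"] != "tool":
            out.append(messages[i])
            i += 1
            continue
        results = []
        # Look ahead and consume the whole run of tool results
        while i < len(messages) and messages[i]["role"] == "tool":
            results.append({
                "type": "tool_result",
                "tool_use_id": messages[i]["tool_call_id"],
                "content": messages[i]["content"],
            })
            i += 1
        out.append({"role": "user", "content": results})
    return out

history = [
    {"role": "tool", "tool_call_id": "t1", "content": "72F"},
    {"role": "tool", "tool_call_id": "t2", "content": "sunny"},
    {"role": "user", "content": "thanks"},
]
# The two consecutive tool results collapse into a single user message
# containing two tool_result blocks, followed by the original user turn.
assert len(group_tool_results(history)) == 2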
voicerun_completions/providers/base.py
@@ -0,0 +1,145 @@
+import json
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, AsyncIterable, Optional
+from primfunctions.completions.streaming import ChatCompletionChunk
+from primfunctions.completions.response import ChatCompletionResponse
+from primfunctions.completions.request import ChatCompletionRequest, StreamOptions
+from primfunctions.completions.messages import ToolCall, FunctionCall
+
+
+@dataclass
+class PartialToolCall:
+    """Internal state for accumulating tool call data."""
+    id: str
+    type: str
+    function_name: str
+    arguments_buffer: str = ""
+    index: Optional[int] = None
+
+    def is_complete(self) -> bool:
+        """Check if the accumulated arguments form valid JSON."""
+        # TODO: better way of doing this?
+        try:
+            json.loads(self.arguments_buffer)
+            return True
+        except json.JSONDecodeError:
+            return False
+
+    def to_tool_call(self) -> ToolCall:
+        """Convert this partial to a ToolCall. Empty arguments if invalid JSON."""
+
+        arguments = {}
+        try:
+            arguments = json.loads(self.arguments_buffer)
+        except json.JSONDecodeError:
+            # Received invalid json arguments
+            pass
+
+        return ToolCall(
+            id=self.id,
+            type=self.type,
+            function=FunctionCall(
+                name=self.function_name,
+                arguments=arguments,
+            ),
+            index=self.index,
+        )
+
+
+class StreamProcessor(ABC):
+    """Processes a provider-specific completion stream and returns ChatCompletionChunks to yield to the client.
+    Accumulates streaming tool call deltas and emits complete tool calls.
+    """
+
+    @abstractmethod
+    async def process_stream(
+        self,
+        stream: AsyncIterable[Any]
+    ) -> AsyncIterable[ChatCompletionChunk]:
+        """Process a stream of provider completion chunks and yield normalized chunks."""
+        pass
+
+
+class CompletionClient(ABC):
+    """Abstract base class for LLM completion clients."""
+
+    @abstractmethod
+    def _denormalize_request(
+        self,
+        request: ChatCompletionRequest,
+    ) -> dict[str, Any]:
+        pass
+
+
+    @abstractmethod
+    def _normalize_response(
+        self,
+        response: Any,
+    ) -> ChatCompletionResponse:
+        pass
+
+
+    @abstractmethod
+    async def _get_completion(
+        self,
+        **kwargs
+    ) -> Any:
+        pass
+
+
+    @abstractmethod
+    async def _get_completion_stream(
+        self,
+        **kwargs
+    ) -> AsyncIterable[Any]:
+        """Get streaming completion from provider."""
+        pass
+
+
+    @abstractmethod
+    def _get_stream_processor(
+        self,
+        stream_options: Optional[StreamOptions] = None,
+    ) -> StreamProcessor:
+        """Get provider-specific StreamProcessor."""
+        pass
+
+
+    async def generate_chat_completion(
+        self,
+        request: ChatCompletionRequest,
+    ) -> ChatCompletionResponse:
+        """
+        Generate chat completion.
+
+        Args:
+            request: Normalized chat completion request
+
+        Returns:
+            ChatCompletionResponse with normalized data
+        """
+        denormalized_request = self._denormalize_request(request)
+        completion = await self._get_completion(**denormalized_request)
+        return self._normalize_response(completion)
+
+
+    async def generate_chat_completion_stream(
+        self,
+        request: ChatCompletionRequest,
+    ) -> AsyncIterable[ChatCompletionChunk]:
+        """
+        Generate streaming chat completion.
+
+        Args:
+            request: Normalized chat completion request with streaming=True
+
+        Returns:
+            AsyncIterable of ChatCompletionChunk with normalized data
+        """
+        denormalized_request = self._denormalize_request(request)
+        processor = self._get_stream_processor(request.stream_options)
+        completion_stream = await self._get_completion_stream(**denormalized_request)
+
+        async for normalized_chunk in processor.process_stream(completion_stream):
+            yield normalized_chunk
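A short sketch of how PartialToolCall is meant to be fed: a provider stream processor appends raw argument deltas to arguments_buffer and converts once the buffer parses as valid JSON. The delta strings below are invented for the example:

import json
from voicerun_completions.providers.base import PartialToolCall

partial = PartialToolCall(id="call_1", type="function", function_name="get_weather")

# Hypothetical argument deltas as a provider might stream them
for delta in ['{"city": ', '"Austin', '"}']:
    partial.arguments_buffer += delta

assert partial.is_complete()
tool_call = partial.to_tool_call()  # ToolCall with parsed dict arguments
assert tool_call.function.arguments == {"city": "Austin"}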
voicerun_completions/providers/google/google_client.py
@@ -0,0 +1,165 @@
+from typing import Any, AsyncIterator, Optional
+from google.genai import Client
+from google.genai.types import (
+    Candidate as GoogleResponseCandidate,
+    Content as GoogleMessageContent,
+    Part as GoogleMessagePart,
+    GenerateContentResponse as GoogleContentResponse,
+    GenerateContentConfigDict as GoogleRequestConfigDict,
+    ThinkingConfigDict as GoogleThinkingConfigDict,
+)
+from primfunctions.completions.messages import AssistantMessage, FunctionCall, ToolCall
+from primfunctions.completions.response import ChatCompletionResponse
+from primfunctions.completions.request import ChatCompletionRequest, StreamOptions
+
+from ..base import CompletionClient
+from .utils import denormalize_conversation_history, denormalize_tools, denormalize_tool_choice
+from .streaming import GoogleStreamProcessor
+
+
+class GoogleCompletionClient(CompletionClient):
+
+    def _denormalize_request(
+        self,
+        request: ChatCompletionRequest,
+    ) -> dict[str, Any]:
+        """Convert ChatCompletionRequest to kwargs for _get_completion."""
+
+        contents, system_instruction = denormalize_conversation_history(request.messages)
+
+        kwargs = {
+            "api_key": request.api_key,
+            "model": request.model,
+            "contents": contents,
+        }
+
+        # Build config using GenerateContentConfigDict
+        config_dict = {}
+
+        if request.temperature is not None:
+            config_dict["temperature"] = request.temperature
+
+        if request.max_tokens is not None:
+            config_dict["max_output_tokens"] = request.max_tokens
+
+        if request.tools:
+            config_dict["tools"] = denormalize_tools(request.tools)
+
+        if request.tool_choice:
+            tool_config = denormalize_tool_choice(request.tool_choice)
+            if tool_config:
+                config_dict["tool_config"] = tool_config
+
+        if request.timeout is not None:
+            # Google takes timeout in ms
+            timeout_ms: int = request.timeout * 1000
+            config_dict["http_options"] = {
+                "timeout": timeout_ms
+            }
+
+        # Add system instruction if present
+        if system_instruction:
+            config_dict["system_instruction"] = system_instruction
+
+        # Disable thinking to improve streaming latency
+        config_dict["thinking_config"] = GoogleThinkingConfigDict(
+            thinking_budget=0,  # 0 = DISABLED
+            include_thoughts=False
+        )
+
+        # Only add config if it has values
+        if config_dict:
+            kwargs["config"] = GoogleRequestConfigDict(**config_dict)
+
+        return kwargs
+
+    def _normalize_response(
+        self,
+        response: GoogleContentResponse,
+    ) -> ChatCompletionResponse:
+
+        # Take only first candidate
+        candidate: GoogleResponseCandidate = response.candidates[0]
+        google_message: GoogleMessageContent = candidate.content
+        google_message_parts: list[GoogleMessagePart] = google_message.parts
+
+        # Extract text and function calls from parts
+        text_parts: list[str] = []
+        tool_calls: list[ToolCall] = []
+        tool_call_index: int = 0
+
+        for part in google_message_parts:
+            if part.text:
+                text_parts.append(part.text)
+            elif part.function_call:
+                # Convert function call to ToolCall
+                tool_calls.append(ToolCall(
+                    id=part.function_call.id,
+                    type="function",
+                    function=FunctionCall(
+                        name=part.function_call.name,
+                        arguments=part.function_call.args
+                    ),
+                    index=tool_call_index,
+                ))
+                tool_call_index += 1
+
+        # Combine text content
+        content = "".join(text_parts) if text_parts else None
+
+        normalized_message = AssistantMessage(
+            content=content,
+            tool_calls=tool_calls if tool_calls else None
+        )
+
+        return ChatCompletionResponse(
+            message=normalized_message,
+            finish_reason=candidate.finish_reason,
+            usage=response.usage_metadata.model_dump() if response.usage_metadata else None
+        )
+
+    async def _get_completion(
+        self,
+        api_key: str,
+        model: str,
+        contents: list[GoogleMessageContent],
+        config: Optional[GoogleRequestConfigDict] = None,
+    ) -> GoogleContentResponse:
+        """Generate content using Google Gen AI client.
+
+        [Client](https://github.com/googleapis/python-genai)
+        """
+
+        async with Client(api_key=api_key).aio as async_client:
+            return await async_client.models.generate_content(
+                model=model,
+                contents=contents,
+                config=config,
+            )
+
+
+    def _get_stream_processor(
+        self,
+        stream_options: Optional[StreamOptions] = None,
+    ) -> GoogleStreamProcessor:
+        """Get Google-specific StreamProcessor."""
+        return GoogleStreamProcessor(stream_options=stream_options)
+
+
+    async def _get_completion_stream(
+        self,
+        api_key: str,
+        model: str,
+        contents: list[GoogleMessageContent],
+        config: Optional[GoogleRequestConfigDict] = None,
+    ) -> AsyncIterator[GoogleContentResponse]:
+        """Stream chat completion events from Google.
+
+        [Client](https://github.com/googleapis/python-genai)
+        """
+        client = Client(api_key=api_key)
+        return await client.aio.models.generate_content_stream(
+            model=model,
+            contents=contents,
+            config=config,
+        )
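A hedged usage sketch for the non-streaming path. The request fields match what _denormalize_request reads (api_key, model, messages, temperature, max_tokens), but the keyword constructors for ChatCompletionRequest and UserMessage are assumptions about primfunctions, and the model name is only an example:

import asyncio
from primfunctions.completions.request import ChatCompletionRequest
from primfunctions.completions.messages import UserMessage
from voicerun_completions.providers.google.google_client import GoogleCompletionClient

async def main() -> None:
    request = ChatCompletionRequest(        # keyword constructor assumed
        api_key="...",                      # your Gemini API key
        model="gemini-2.0-flash",           # example model name
        messages=[UserMessage(content="Say hello in one sentence.")],
        temperature=0.2,
        max_tokens=128,
    )
    client = GoogleCompletionClient()
    response = await client.generate_chat_completion(request)
    print(response.message.content)

asyncio.run(main())

Note that _denormalize_request unconditionally sets thinking_config with thinking_budget=0, so thinking is disabled for every Google request this client makes.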
voicerun_completions/providers/google/streaming.py
@@ -0,0 +1,177 @@
+import httpx
+from typing import Any, AsyncIterable, Dict, List, Optional
+from google.genai.types import (
+    Candidate as GoogleResponseCandidate,
+    GenerateContentResponse as GoogleResponseChunk,
+    Content as GoogleMessageContent,
+    Part as GoogleMessagePart,
+)
+from primfunctions.completions.messages import ToolCall, FunctionCall, AssistantMessage
+from primfunctions.completions.response import ChatCompletionResponse
+from primfunctions.completions.streaming import (
+    ChatCompletionChunk,
+    AssistantMessageDeltaChunk,
+    AssistantMessageSentenceChunk,
+    FinishReasonChunk,
+    ToolCallChunk,
+    UsageChunk,
+    FinalResponseChunk,
+)
+from primfunctions.utils.streaming import update_sentence_buffer, clean_text_for_speech
+from primfunctions.completions.request import StreamOptions
+
+from ..base import StreamProcessor, PartialToolCall
+
+
+class GoogleStreamProcessor(StreamProcessor):
+    """Processes Google message stream events, yielding normalized chunks."""
+
+    def __init__(
+        self,
+        stream_options: Optional[StreamOptions] = None,
+    ):
+        self.stream_sentences: bool = False
+        self.clean_sentences: bool = True
+        self.min_sentence_length: int = 6
+        self.punctuation_marks: Optional[list[str]] = None
+        self.punctuation_language: Optional[str] = None
+
+        # Override stream options defaults
+        if stream_options:
+            self.stream_sentences = stream_options.stream_sentences
+            self.clean_sentences = stream_options.clean_sentences
+            self.min_sentence_length = stream_options.min_sentence_length
+            self.punctuation_marks = stream_options.punctuation_marks
+            self.punctuation_language = stream_options.punctuation_language
+
+        self.active_calls: Dict[int, PartialToolCall] = {}
+        self.content: str = ""
+        self.tool_calls: List[ToolCall] = []
+        self.finish_reason: str = ""
+        self.usage: Dict[str, Any] = {}
+        self.sentence_buffer = ""
+        self.current_block_index = 0
+
+
+    def _process_chunk(
+        self,
+        chunk: GoogleResponseChunk,
+    ) -> List[ChatCompletionChunk]:
+        """Convert Google ContentResponse to individual typed chunks."""
+        chunks = []
+
+        # Handle usage information
+        if chunk.usage_metadata:
+            # TODO: ensure cumulative
+            self.usage = chunk.usage_metadata.model_dump()
+
+        # Handle candidates
+        if not chunk.candidates:
+            return chunks
+
+        # Take only first candidate (Google pattern)
+        candidate: GoogleResponseCandidate = chunk.candidates[0]
+        content: GoogleMessageContent = candidate.content
+        parts: list[GoogleMessagePart] = content.parts
+
+        # Handle finish reason
+        if candidate.finish_reason:
+            self.finish_reason = candidate.finish_reason.value.lower()
+
+        # Process each part in the content
+        for part in parts:
+            # Handle text content
+            if part.text:
+                chunks.extend(self._process_text_partial(part.text))
+
+            # Handle function calls (tool calls)
+            elif part.function_call:
+                # Google returns complete function calls, not deltas
+                tool_call = ToolCall(
+                    id=part.function_call.id or f"call_{self.current_block_index}",
+                    type="function",
+                    function=FunctionCall(
+                        name=part.function_call.name,
+                        arguments=part.function_call.args,
+                    ),
+                    index=self.current_block_index,
+                )
+
+                chunks.append(ToolCallChunk(tool_call=tool_call))
+                self.tool_calls.append(tool_call)
+                self.current_block_index += 1
+
+        return chunks
+
+    def _process_text_partial(self, text: str) -> List[ChatCompletionChunk]:
+        """Process text content and return appropriate chunks."""
+        chunks = []
+
+        if self.stream_sentences:
+            # Append delta to sentence buffer
+            sentence_buffer, complete_sentence = update_sentence_buffer(
+                content=text,
+                sentence_buffer=self.sentence_buffer,
+                punctuation_marks=self.punctuation_marks,
+                clean_text=self.clean_sentences,
+                min_sentence_length=self.min_sentence_length,
+            )
+            self.sentence_buffer = sentence_buffer
+
+            if complete_sentence:
+                chunks.append(AssistantMessageSentenceChunk(
+                    sentence=complete_sentence
+                ))
+        else:
+            # Otherwise stream content delta directly
+            chunks.append(AssistantMessageDeltaChunk(
+                content=text
+            ))
+
+        # Add content delta to accumulated response
+        self.content += text
+
+        return chunks
+
+    async def process_stream(
+        self,
+        stream: AsyncIterable[GoogleResponseChunk],
+    ) -> AsyncIterable[ChatCompletionChunk]:
+        """Process Google event stream and yield normalized chunks."""
+        try:
+            async for event in stream:
+                for chunk in self._process_chunk(event):
+                    yield chunk
+        except Exception as e:
+            # Handle Google API connection issues that occur after successful content delivery
+            if isinstance(e, httpx.ReadError) and str(e) == '':
+                # This is a known issue with Google's streaming API - ignore empty ReadErrors
+                # that occur after content has been successfully delivered
+                pass
+            else:
+                raise
+
+        # Handle remaining sentence buffer if streaming sentences
+        if self.stream_sentences and self.sentence_buffer:
+            complete_sentence = clean_text_for_speech(self.sentence_buffer) if self.clean_sentences else self.sentence_buffer
+            yield AssistantMessageSentenceChunk(
+                sentence=complete_sentence
+            )
+
+        # Yield the finish chunk (or default)
+        yield FinishReasonChunk(finish_reason=self.finish_reason or "stop")
+
+        # Yield the usage chunk
+        yield UsageChunk(usage=self.usage)
+
+        # Yield aggregated chat completion response as final chunk
+        yield FinalResponseChunk(
+            response=ChatCompletionResponse(
+                message=AssistantMessage(
+                    content=self.content or None,
+                    tool_calls=self.tool_calls or None,
+                ),
+                usage=self.usage,
+                finish_reason=self.finish_reason,
+            )
+        )
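The sentence streaming above delegates to primfunctions.utils.streaming.update_sentence_buffer, which is not shown in this diff. A standalone sketch of the assumed semantics, matching the (buffer, sentence) return shape at the call site: append each delta to a buffer, then flush everything up to the last sentence-ending mark once a minimum length is reached.

def update_buffer(content: str, buffer: str, marks: str = ".!?", min_len: int = 6):
    """Append a delta, then flush through the last sentence-ending mark if long enough."""
    buffer += content
    # Position of the last sentence-ending punctuation mark in the buffer
    cut = max(buffer.rfind(m) for m in marks)
    if cut >= 0 and cut + 1 >= min_len:
        return buffer[cut + 1:], buffer[: cut + 1].strip()
    return buffer, None

buffer = ""
for delta in ["Hello the", "re! How are", " you today?"]:
    buffer, sentence = update_buffer(delta, buffer)
    if sentence:
        print(sentence)  # -> "Hello there!" then "How are you today?"

Buffering whole sentences this way trades a little latency for chunks that a downstream text-to-speech stage can voice naturally, which is presumably why the processor exposes stream_sentences, min_sentence_length, and the punctuation options.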