agnt5-0.1.0-cp39-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agnt5/__init__.py +307 -0
- agnt5/__pycache__/__init__.cpython-311.pyc +0 -0
- agnt5/__pycache__/agent.cpython-311.pyc +0 -0
- agnt5/__pycache__/context.cpython-311.pyc +0 -0
- agnt5/__pycache__/durable.cpython-311.pyc +0 -0
- agnt5/__pycache__/extraction.cpython-311.pyc +0 -0
- agnt5/__pycache__/memory.cpython-311.pyc +0 -0
- agnt5/__pycache__/reflection.cpython-311.pyc +0 -0
- agnt5/__pycache__/runtime.cpython-311.pyc +0 -0
- agnt5/__pycache__/task.cpython-311.pyc +0 -0
- agnt5/__pycache__/tool.cpython-311.pyc +0 -0
- agnt5/__pycache__/tracing.cpython-311.pyc +0 -0
- agnt5/__pycache__/types.cpython-311.pyc +0 -0
- agnt5/__pycache__/workflow.cpython-311.pyc +0 -0
- agnt5/_core.abi3.so +0 -0
- agnt5/agent.py +1086 -0
- agnt5/context.py +406 -0
- agnt5/durable.py +1050 -0
- agnt5/extraction.py +410 -0
- agnt5/llm/__init__.py +179 -0
- agnt5/llm/__pycache__/__init__.cpython-311.pyc +0 -0
- agnt5/llm/__pycache__/anthropic.cpython-311.pyc +0 -0
- agnt5/llm/__pycache__/azure.cpython-311.pyc +0 -0
- agnt5/llm/__pycache__/base.cpython-311.pyc +0 -0
- agnt5/llm/__pycache__/google.cpython-311.pyc +0 -0
- agnt5/llm/__pycache__/mistral.cpython-311.pyc +0 -0
- agnt5/llm/__pycache__/openai.cpython-311.pyc +0 -0
- agnt5/llm/__pycache__/together.cpython-311.pyc +0 -0
- agnt5/llm/anthropic.py +319 -0
- agnt5/llm/azure.py +348 -0
- agnt5/llm/base.py +315 -0
- agnt5/llm/google.py +373 -0
- agnt5/llm/mistral.py +330 -0
- agnt5/llm/model_registry.py +467 -0
- agnt5/llm/models.json +227 -0
- agnt5/llm/openai.py +334 -0
- agnt5/llm/together.py +377 -0
- agnt5/memory.py +746 -0
- agnt5/reflection.py +514 -0
- agnt5/runtime.py +699 -0
- agnt5/task.py +476 -0
- agnt5/testing.py +451 -0
- agnt5/tool.py +516 -0
- agnt5/tracing.py +624 -0
- agnt5/types.py +210 -0
- agnt5/workflow.py +897 -0
- agnt5-0.1.0.dist-info/METADATA +93 -0
- agnt5-0.1.0.dist-info/RECORD +49 -0
- agnt5-0.1.0.dist-info/WHEEL +4 -0
agnt5/llm/base.py
ADDED
@@ -0,0 +1,315 @@
"""
Base classes and interfaces for LLM provider integration.

Defines the common interface that all LLM providers must implement,
along with shared types and utilities.
"""

import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from datetime import datetime


class Role(Enum):
    """Message roles in conversations."""
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"


class LanguageModelType:
    """
    Simple model type that parses provider/model format.

    Supports formats:
    - "provider/model" (e.g., "anthropic/claude-3-5-sonnet")
    - "model" (auto-detects provider, e.g., "gpt-4o")
    """

    def __init__(self, model_string: str):
        """
        Initialize with a model string.

        Args:
            model_string: Either "provider/model" or just "model"
        """
        self.original = model_string

        if "/" in model_string:
            # Format: "provider/model"
            self._provider, self.value = model_string.split("/", 1)
        else:
            # Format: "model" - auto-detect provider
            self.value = model_string
            self._provider = self._detect_provider(model_string)

    @classmethod
    def from_string(cls, model_string: str) -> "LanguageModelType":
        """Create LanguageModelType from string."""
        return cls(model_string)

    def _detect_provider(self, model_name: str) -> str:
        """Auto-detect provider from model name."""
        model_lower = model_name.lower()

        # Claude models
        if any(keyword in model_lower for keyword in ["claude", "sonnet", "haiku", "opus"]):
            return "anthropic"

        # OpenAI models
        elif model_lower.startswith("gpt"):
            return "openai"

        # Google models
        elif model_lower.startswith("gemini"):
            return "google"

        # Mistral models
        elif any(model_lower.startswith(prefix) for prefix in ["mistral", "codestral"]):
            return "mistral"

        # Together AI models (usually have namespace/model format, but handle edge cases)
        elif any(keyword in model_lower for keyword in ["llama", "mixtral", "qwen"]) or model_name.count("/") > 0:
            return "together"

        # Default to OpenAI for unknown models
        else:
            return "openai"

    def get_provider(self) -> str:
        """Get the provider name for this model."""
        return self._provider

    @property
    def provider(self) -> str:
        """Provider property for easy access."""
        return self._provider

    @property
    def model(self) -> str:
        """Model name property for easy access."""
        return self.value

    def __str__(self):
        return self.value

    def __repr__(self):
        return f"LanguageModelType('{self.original}' -> provider='{self._provider}', model='{self.value}')"

    def __eq__(self, other):
        if isinstance(other, LanguageModelType):
            return self.value == other.value and self._provider == other._provider
        elif isinstance(other, str):
            return self.value == other or self.original == other
        return False

    def __hash__(self):
        return hash((self._provider, self.value))


@dataclass
class TokenUsage:
    """Token usage information."""
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


@dataclass
class ToolCall:
    """Represents a tool call made by the LLM."""
    id: str
    name: str
    arguments: Dict[str, Any]


@dataclass
class ToolResult:
    """Represents the result of a tool call."""
    tool_call_id: str
    output: Optional[Any] = None
    error: Optional[str] = None


@dataclass
class Message:
    """Represents a message in a conversation."""
    role: Role
    content: Union[str, List[Dict[str, Any]]]
    name: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.utcnow)


@dataclass
class LanguageModelResponse:
    """Response from a language model."""
    message: str
    usage: TokenUsage
    tool_calls: Optional[List[ToolCall]] = None
    model: Optional[str] = None
    finish_reason: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


class LLMError(Exception):
    """Base exception for LLM-related errors."""

    def __init__(self, message: str, provider: str = None, model: str = None, **kwargs):
        super().__init__(message)
        self.message = message
        self.provider = provider
        self.model = model
        self.metadata = kwargs


class LanguageModel(ABC):
    """
    Abstract base class for all language model providers.

    Provides a unified interface for different LLM providers with support for:
    - Text generation with tool calling
    - Streaming responses
    - Message format conversion
    - Error handling
    """

    def __init__(
        self,
        llm_model: LanguageModelType,
        system_prompt: Optional[str] = None,
        **kwargs
    ):
        self.llm_model = llm_model
        self.system_prompt = system_prompt
        self.config = kwargs

    @property
    def provider_name(self) -> str:
        """Get the provider name."""
        return self.llm_model.get_provider()

    @property
    def model_name(self) -> str:
        """Get the model name."""
        return self.llm_model.value

    @abstractmethod
    async def generate(
        self,
        messages: List[Message],
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        top_p: float = 1.0,
        stream: bool = False,
        **kwargs
    ) -> Union[LanguageModelResponse, AsyncIterator[LanguageModelResponse]]:
        """
        Generate a response from the language model.

        Args:
            messages: List of conversation messages
            tools: Optional list of available tools
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature (0.0 to 2.0)
            top_p: Nucleus sampling parameter
            stream: Whether to stream the response
            **kwargs: Provider-specific parameters

        Returns:
            LanguageModelResponse or async iterator for streaming
        """
        pass

    @abstractmethod
    def convert_messages_to_provider_format(self, messages: List[Message]) -> List[Dict[str, Any]]:
        """Convert internal message format to provider-specific format."""
        pass

    @abstractmethod
    def convert_tools_to_provider_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert internal tool format to provider-specific format."""
        pass

    def prepare_system_message(self, messages: List[Message]) -> List[Message]:
        """Prepare system message for the conversation."""
        if not self.system_prompt:
            return messages

        # Check if system message already exists
        if messages and messages[0].role == Role.SYSTEM:
            # Merge with existing system message
            existing_content = messages[0].content
            if isinstance(existing_content, str):
                combined_content = f"{self.system_prompt}\n\n{existing_content}"
            else:
                combined_content = self.system_prompt

            messages[0].content = combined_content
            return messages
        else:
            # Add new system message
            system_message = Message(role=Role.SYSTEM, content=self.system_prompt)
            return [system_message] + messages

    def validate_messages(self, messages: List[Message]) -> None:
        """Validate message format and content."""
        if not messages:
            raise LLMError("No messages provided", provider=self.provider_name)

        for i, message in enumerate(messages):
            if not isinstance(message.role, Role):
                raise LLMError(f"Invalid role at message {i}: {message.role}", provider=self.provider_name)

            if not message.content and not message.tool_calls:
                raise LLMError(f"Empty content at message {i}", provider=self.provider_name)

    def extract_tool_calls_from_response(self, response: Any) -> List[ToolCall]:
        """Extract tool calls from provider response."""
        # Default implementation - override in subclasses
        return []

    async def generate_with_retry(
        self,
        messages: List[Message],
        max_retries: int = 3,
        **kwargs
    ) -> LanguageModelResponse:
        """Generate response with automatic retry on failure."""
        last_error = None

        for attempt in range(max_retries + 1):
            try:
                result = await self.generate(messages, **kwargs)
                if isinstance(result, AsyncIterator):
                    # Convert streaming to single response for retry logic
                    response_text = ""
                    async for chunk in result:
                        response_text += chunk.message
                    return LanguageModelResponse(
                        message=response_text,
                        usage=TokenUsage(),
                        model=self.model_name
                    )
                return result
            except Exception as e:
                last_error = e
                if attempt < max_retries:
                    # Exponential backoff
                    import asyncio
                    await asyncio.sleep(2 ** attempt)
                    continue
                break

        raise LLMError(
            f"Failed after {max_retries + 1} attempts: {last_error}",
            provider=self.provider_name,
            model=self.model_name
        ) from last_error
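
For orientation, here is a minimal usage sketch of the base abstractions above. It is not part of the wheel: the EchoModel subclass, the model strings, and the "You are terse." prompt are illustrative assumptions; only the agnt5.llm.base names come from the package.

# Hypothetical usage of the types shipped in agnt5/llm/base.py.
import asyncio
from agnt5.llm.base import (
    LanguageModel, LanguageModelResponse, LanguageModelType, Message, Role, TokenUsage,
)

# "provider/model" strings are split on the first "/".
model_type = LanguageModelType("anthropic/claude-3-5-sonnet")
assert model_type.provider == "anthropic" and model_type.model == "claude-3-5-sonnet"

# Bare model names go through _detect_provider(); "gpt-4o" resolves to "openai".
assert LanguageModelType("gpt-4o").provider == "openai"

class EchoModel(LanguageModel):
    """Toy provider that satisfies the abstract interface (illustrative only)."""

    async def generate(self, messages, tools=None, max_tokens=1024,
                       temperature=0.7, top_p=1.0, stream=False, **kwargs):
        # Echo the last message back as the "model" output.
        return LanguageModelResponse(message=messages[-1].content, usage=TokenUsage())

    def convert_messages_to_provider_format(self, messages):
        return [{"role": m.role.value, "content": m.content} for m in messages]

    def convert_tools_to_provider_format(self, tools):
        return tools

async def main():
    llm = EchoModel(model_type, system_prompt="You are terse.")
    # generate_with_retry() wraps generate() with exponential backoff (2 ** attempt seconds).
    response = await llm.generate_with_retry([Message(role=Role.USER, content="ping")])
    print(response.message)

asyncio.run(main())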
agnt5/llm/google.py
ADDED
@@ -0,0 +1,373 @@
"""
Google Gemini integration for AGNT5 SDK.

Provides integration with Google's Gemini models including proper message conversion,
tool calling, and streaming support.
"""

import json
import os
from typing import Any, AsyncIterator, Dict, List, Optional, Union

from .base import (
    LanguageModel,
    LanguageModelResponse,
    LanguageModelType,
    LLMError,
    Message,
    Role,
    TokenUsage,
    ToolCall,
)

try:
    import google.generativeai as genai
    from google.generativeai.types import (
        GenerateContentResponse,
        Content,
        Part,
        FunctionCall,
        FunctionResponse,
    )
    GOOGLE_AVAILABLE = True
except ImportError:
    GOOGLE_AVAILABLE = False
    # Define placeholder types for type hints when Google SDK is not available
    Content = Any
    Part = Any
    FunctionCall = Any
    FunctionResponse = Any
    GenerateContentResponse = Any


class GoogleError(LLMError):
    """Google Gemini-specific errors."""
    pass


class GoogleLanguageModel(LanguageModel):
    """
    Google Gemini language model implementation.

    Supports all Gemini models with proper message conversion, tool calling,
    and streaming capabilities.
    """

    def __init__(
        self,
        llm_model: LanguageModelType,
        api_key: Optional[str] = None,
        system_prompt: Optional[str] = None,
        **kwargs
    ):
        if not GOOGLE_AVAILABLE:
            raise GoogleError("Google AI library not installed. Install with: pip install google-generativeai")

        super().__init__(llm_model, system_prompt, **kwargs)

        # Get API key
        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            raise GoogleError("Google API key required. Set GOOGLE_API_KEY or pass api_key parameter")

        # Configure the client
        genai.configure(api_key=self.api_key)

        # Validate model is supported by Google
        if not self.model_name.startswith("gemini"):
            raise GoogleError(f"Model {self.model_name} is not a Google Gemini model")

        # Initialize the model
        self.model = genai.GenerativeModel(self.model_name)

    async def generate(
        self,
        messages: List[Message],
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        top_p: float = 1.0,
        stream: bool = False,
        **kwargs
    ) -> Union[LanguageModelResponse, AsyncIterator[LanguageModelResponse]]:
        """Generate response using Google Gemini."""
        try:
            # Validate and prepare messages
            self.validate_messages(messages)
            prepared_messages = self.prepare_system_message(messages)

            # Convert to Gemini format
            gemini_contents = self.convert_messages_to_provider_format(prepared_messages)

            # Prepare generation config
            generation_config = {
                "max_output_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
            }
            generation_config.update(kwargs)

            # Prepare tools if provided
            gemini_tools = None
            if tools:
                gemini_tools = self.convert_tools_to_provider_format(tools)

            if stream:
                return self._generate_stream(gemini_contents, generation_config, gemini_tools)
            else:
                return await self._generate_single(gemini_contents, generation_config, gemini_tools)

        except Exception as e:
            # Handle various Google AI exceptions
            error_msg = str(e)
            if "API_KEY" in error_msg.upper():
                raise GoogleError(f"Google API authentication error: {e}", provider="google", model=self.model_name) from e
            elif "QUOTA" in error_msg.upper() or "RATE_LIMIT" in error_msg.upper():
                raise GoogleError(f"Google API quota/rate limit error: {e}", provider="google", model=self.model_name) from e
            else:
                raise GoogleError(f"Google API error: {e}", provider="google", model=self.model_name) from e

    async def _generate_single(
        self,
        contents: List[Content],
        generation_config: Dict[str, Any],
        tools: Optional[List[Dict[str, Any]]] = None
    ) -> LanguageModelResponse:
        """Generate a single response."""
        try:
            # Create generation arguments
            generate_kwargs = {
                "contents": contents,
                "generation_config": generation_config,
            }

            if tools:
                generate_kwargs["tools"] = tools

            # Generate content
            response = self.model.generate_content(**generate_kwargs)

            # Extract text content
            response_text = ""
            tool_calls = []

            if response.candidates:
                candidate = response.candidates[0]
                for part in candidate.content.parts:
                    if hasattr(part, 'text') and part.text:
                        response_text += part.text
                    elif hasattr(part, 'function_call') and part.function_call:
                        # Extract function call
                        func_call = part.function_call
                        tool_calls.append(ToolCall(
                            id=f"call_{hash(func_call.name)}",  # Generate ID since Gemini doesn't provide one
                            name=func_call.name,
                            arguments=dict(func_call.args) if func_call.args else {}
                        ))

            # Calculate token usage (Gemini provides usage in response)
            usage = TokenUsage()
            if hasattr(response, 'usage_metadata') and response.usage_metadata:
                usage = TokenUsage(
                    prompt_tokens=response.usage_metadata.prompt_token_count,
                    completion_tokens=response.usage_metadata.candidates_token_count,
                    total_tokens=response.usage_metadata.total_token_count
                )

            return LanguageModelResponse(
                message=response_text,
                usage=usage,
                tool_calls=tool_calls if tool_calls else None,
                model=self.model_name,
                finish_reason=getattr(response.candidates[0], 'finish_reason', None) if response.candidates else None,
                metadata={
                    "safety_ratings": getattr(response.candidates[0], 'safety_ratings', []) if response.candidates else []
                }
            )

        except Exception as e:
            raise GoogleError(f"Error generating single response: {e}", provider="google", model=self.model_name) from e

    async def _generate_stream(
        self,
        contents: List[Content],
        generation_config: Dict[str, Any],
        tools: Optional[List[Dict[str, Any]]] = None
    ) -> AsyncIterator[LanguageModelResponse]:
        """Generate streaming response."""
        try:
            # Create generation arguments
            generate_kwargs = {
                "contents": contents,
                "generation_config": generation_config,
                "stream": True,
            }

            if tools:
                generate_kwargs["tools"] = tools

            # Generate streaming content
            response_stream = self.model.generate_content(**generate_kwargs)

            for chunk in response_stream:
                if chunk.candidates:
                    candidate = chunk.candidates[0]
                    for part in candidate.content.parts:
                        if hasattr(part, 'text') and part.text:
                            yield LanguageModelResponse(
                                message=part.text,
                                usage=TokenUsage(),
                                model=self.model_name
                            )
                        elif hasattr(part, 'function_call') and part.function_call:
                            # Handle function calls in streaming
                            func_call = part.function_call
                            yield LanguageModelResponse(
                                message="",
                                usage=TokenUsage(),
                                tool_calls=[ToolCall(
                                    id=f"call_{hash(func_call.name)}",
                                    name=func_call.name,
                                    arguments=dict(func_call.args) if func_call.args else {}
                                )],
                                model=self.model_name
                            )

        except Exception as e:
            raise GoogleError(f"Error generating streaming response: {e}", provider="google", model=self.model_name) from e

    def convert_messages_to_provider_format(self, messages: List[Message]) -> List[Content]:
        """Convert internal messages to Gemini Content format."""
        contents = []
        system_instruction = None

        for message in messages:
            # Handle system messages separately
            if message.role == Role.SYSTEM:
                system_instruction = message.content
                continue

            # Convert role
            if message.role == Role.USER:
                role = "user"
            elif message.role == Role.ASSISTANT:
                role = "model"  # Gemini uses "model" instead of "assistant"
            elif message.role == Role.TOOL:
                # Tool results are handled as function responses
                role = "function"
            else:
                continue  # Skip unsupported roles

            # Prepare content parts
            parts = []

            if isinstance(message.content, str):
                if message.content:  # Only add non-empty content
                    parts.append(Part(text=message.content))
            elif isinstance(message.content, list):
                parts.extend(self._convert_content_blocks(message.content))

            # Handle tool calls for assistant messages
            if message.tool_calls and message.role == Role.ASSISTANT:
                for tool_call in message.tool_calls:
                    parts.append(Part(
                        function_call=FunctionCall(
                            name=tool_call.name,
                            args=tool_call.arguments
                        )
                    ))

            # Handle tool results
            if message.tool_call_id and message.role == Role.TOOL:
                # Tool results in Gemini are handled as function responses
                parts.append(Part(
                    function_response=FunctionResponse(
                        name=f"function_{message.tool_call_id}",  # Reconstruct function name
                        response={"result": message.content}
                    )
                ))

            if parts:  # Only add content if there are parts
                contents.append(Content(role=role, parts=parts))

        # Add system instruction to the model if present
        if system_instruction and contents:
            # Prepend system instruction as user message (Gemini pattern)
            system_content = Content(
                role="user",
                parts=[Part(text=f"System: {system_instruction}")]
            )
            contents.insert(0, system_content)

        return contents

    def _convert_content_blocks(self, content_blocks: List[Dict[str, Any]]) -> List[Part]:
        """Convert content blocks to Gemini Part format."""
        parts = []

        for block in content_blocks:
            if isinstance(block, str):
                parts.append(Part(text=block))
            elif isinstance(block, dict):
                block_type = block.get("type", "text")

                if block_type == "text":
                    text_content = block.get("text", str(block))
                    if text_content:
                        parts.append(Part(text=text_content))
                elif block_type == "image" or block_type == "image_url":
                    # Handle image content (Gemini supports images)
                    # This would require additional image processing
                    # For now, convert to text description
                    image_desc = block.get("alt_text", "Image content")
                    parts.append(Part(text=f"[Image: {image_desc}]"))
                else:
                    # Convert unknown blocks to text
                    parts.append(Part(text=str(block)))
            else:
                parts.append(Part(text=str(block)))

        return parts

    def convert_tools_to_provider_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert tools to Gemini function format."""
        gemini_functions = []

        for tool in tools:
            if "function" in tool:
                # OpenAI-style tool format
                func = tool["function"]
                gemini_function = {
                    "name": func["name"],
                    "description": func.get("description", ""),
                    "parameters": func.get("parameters", {})
                }
            else:
                # Direct format or simple format
                gemini_function = {
                    "name": tool.get("name", "unknown"),
                    "description": tool.get("description", ""),
                    "parameters": tool.get("parameters", tool.get("input_schema", {}))
                }

            gemini_functions.append(gemini_function)

        # Wrap functions in the expected Gemini tools format
        return [{"function_declarations": gemini_functions}]

    def extract_tool_calls_from_response(self, response: Any) -> List[ToolCall]:
        """Extract tool calls from Gemini response."""
        tool_calls = []

        if hasattr(response, "candidates") and response.candidates:
            candidate = response.candidates[0]
            for part in candidate.content.parts:
                if hasattr(part, 'function_call') and part.function_call:
                    func_call = part.function_call
                    tool_calls.append(ToolCall(
                        id=f"call_{hash(func_call.name)}",
                        name=func_call.name,
                        arguments=dict(func_call.args) if func_call.args else {}
                    ))

        return tool_calls
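
Under the same caveat, a hedged driver sketch for GoogleLanguageModel follows; it is not part of the wheel. The gemini-1.5-flash model name, the get_weather tool schema, and the prompts are illustrative assumptions, and a real run requires the google-generativeai package plus a GOOGLE_API_KEY.

# Hypothetical driver for GoogleLanguageModel from agnt5/llm/google.py.
import asyncio
from agnt5.llm.base import LanguageModelType, Message, Role
from agnt5.llm.google import GoogleLanguageModel

async def main():
    # api_key defaults to the GOOGLE_API_KEY environment variable.
    llm = GoogleLanguageModel(
        LanguageModelType("google/gemini-1.5-flash"),
        system_prompt="Answer briefly.",
    )

    # OpenAI-style tool schema; convert_tools_to_provider_format() wraps it
    # in a single {"function_declarations": [...]} entry for Gemini.
    tools = [{
        "function": {
            "name": "get_weather",
            "description": "Look up current weather",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        }
    }]

    # Non-streaming call returns a single LanguageModelResponse.
    response = await llm.generate(
        [Message(role=Role.USER, content="What's the weather in Paris?")],
        tools=tools,
        max_tokens=256,
    )
    print(response.message, response.tool_calls)

    # stream=True returns an async iterator of incremental responses.
    stream = await llm.generate(
        [Message(role=Role.USER, content="Count to three.")],
        stream=True,
    )
    async for chunk in stream:
        print(chunk.message, end="")

asyncio.run(main())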