voxagent-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- voxagent/__init__.py +143 -0
- voxagent/_version.py +5 -0
- voxagent/agent/__init__.py +32 -0
- voxagent/agent/abort.py +178 -0
- voxagent/agent/core.py +902 -0
- voxagent/code/__init__.py +9 -0
- voxagent/mcp/__init__.py +16 -0
- voxagent/mcp/manager.py +188 -0
- voxagent/mcp/tool.py +152 -0
- voxagent/providers/__init__.py +110 -0
- voxagent/providers/anthropic.py +498 -0
- voxagent/providers/augment.py +293 -0
- voxagent/providers/auth.py +116 -0
- voxagent/providers/base.py +268 -0
- voxagent/providers/chatgpt.py +415 -0
- voxagent/providers/claudecode.py +162 -0
- voxagent/providers/cli_base.py +265 -0
- voxagent/providers/codex.py +183 -0
- voxagent/providers/failover.py +90 -0
- voxagent/providers/google.py +532 -0
- voxagent/providers/groq.py +96 -0
- voxagent/providers/ollama.py +425 -0
- voxagent/providers/openai.py +435 -0
- voxagent/providers/registry.py +175 -0
- voxagent/py.typed +1 -0
- voxagent/security/__init__.py +14 -0
- voxagent/security/events.py +75 -0
- voxagent/security/filter.py +169 -0
- voxagent/security/registry.py +87 -0
- voxagent/session/__init__.py +39 -0
- voxagent/session/compaction.py +237 -0
- voxagent/session/lock.py +103 -0
- voxagent/session/model.py +109 -0
- voxagent/session/storage.py +184 -0
- voxagent/streaming/__init__.py +52 -0
- voxagent/streaming/emitter.py +286 -0
- voxagent/streaming/events.py +255 -0
- voxagent/subagent/__init__.py +20 -0
- voxagent/subagent/context.py +124 -0
- voxagent/subagent/definition.py +172 -0
- voxagent/tools/__init__.py +32 -0
- voxagent/tools/context.py +50 -0
- voxagent/tools/decorator.py +175 -0
- voxagent/tools/definition.py +131 -0
- voxagent/tools/executor.py +109 -0
- voxagent/tools/policy.py +89 -0
- voxagent/tools/registry.py +89 -0
- voxagent/types/__init__.py +46 -0
- voxagent/types/messages.py +134 -0
- voxagent/types/run.py +176 -0
- voxagent-0.1.0.dist-info/METADATA +186 -0
- voxagent-0.1.0.dist-info/RECORD +53 -0
- voxagent-0.1.0.dist-info/WHEEL +4 -0

voxagent/providers/google.py
@@ -0,0 +1,532 @@
+"""Google (Gemini) provider implementation.
+
+This module implements the GoogleProvider for Gemini models,
+supporting streaming, tool use, and multimodal content.
+"""
+
+import json
+import os
+import uuid
+from collections.abc import AsyncIterator
+from typing import Any
+
+import httpx
+
+from voxagent.providers.base import (
+    AbortSignal,
+    BaseProvider,
+    ErrorChunk,
+    MessageEndChunk,
+    StreamChunk,
+    TextDeltaChunk,
+    ToolUseChunk,
+)
+from voxagent.types import Message, ToolCall
+from voxagent.types.messages import ToolResultBlock
+
+
+class GoogleProvider(BaseProvider):
+    """Provider for Google Gemini models.
+
+    Supports Gemini 2.0 Flash, Gemini 1.5 Pro, Gemini 1.5 Flash, and other models.
+    Implements streaming, tool use, and large context windows.
+    """
+
+    ENV_KEY = "GOOGLE_API_KEY"
+    DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
+
+    SUPPORTED_MODELS = [
+        "gemini-2.0-flash",
+        "gemini-2.0-flash-thinking",
+        "gemini-1.5-pro",
+        "gemini-1.5-flash",
+    ]
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        base_url: str | None = None,
+        model: str = "gemini-2.0-flash",
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the Google provider.
+
+        Args:
+            api_key: API key for authentication. Falls back to GOOGLE_API_KEY,
+                then GEMINI_API_KEY env vars.
+            base_url: Optional custom base URL for API requests.
+            model: Model to use (default: gemini-2.0-flash).
+            **kwargs: Additional provider-specific arguments.
+        """
+        super().__init__(api_key=api_key, base_url=base_url, **kwargs)
+        self._model = model
+
+    @property
+    def api_key(self) -> str | None:
+        """Get API key from constructor or environment variables."""
+        if self._api_key is not None:
+            return self._api_key
+        # Check GOOGLE_API_KEY first, then GEMINI_API_KEY
+        return os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")
+
+    @property
+    def name(self) -> str:
+        """Get the provider name."""
+        return "google"
+
+    @property
+    def model(self) -> str:
+        """Get the current model."""
+        return self._model
+
+    @property
+    def models(self) -> list[str]:
+        """Get the list of supported model names."""
+        return self.SUPPORTED_MODELS.copy()
+
+    @property
+    def supports_tools(self) -> bool:
+        """Check if the provider supports tool/function calling."""
+        return True
+
+    @property
+    def supports_streaming(self) -> bool:
+        """Check if the provider supports streaming responses."""
+        return True
+
+    @property
+    def context_limit(self) -> int:
+        """Get the maximum context length in tokens."""
+        return 1000000
+
+    def _get_base_url(self) -> str:
+        """Get the base URL for API requests."""
+        return self._base_url or self.DEFAULT_BASE_URL
+
+    def _get_stream_endpoint(self) -> str:
+        """Get the streaming endpoint URL."""
+        return f"{self._get_base_url()}/models/{self._model}:streamGenerateContent"
+
+    def _get_complete_endpoint(self) -> str:
+        """Get the non-streaming endpoint URL."""
+        return f"{self._get_base_url()}/models/{self._model}:generateContent"
+
+    def _get_request_url(self, action: str) -> str:
+        """Get the full request URL with API key."""
+        base = self._get_base_url()
+        return f"{base}/models/{self._model}:{action}?key={self.api_key}"
+
+    def _convert_messages_to_gemini(
+        self, messages: list[Message]
+    ) -> list[dict[str, Any]]:
+        """Convert voxagent Messages to Gemini contents format.
+
+        Args:
+            messages: List of voxagent Messages.
+
+        Returns:
+            List of Gemini-formatted content dicts.
+        """
+        contents: list[dict[str, Any]] = []
+
+        for msg in messages:
+            if msg.role == "system":
+                # System messages are handled via system_instruction
+                continue
+
+            # Map role: user -> user, assistant -> model
+            role = "model" if msg.role == "assistant" else "user"
+
+            parts: list[dict[str, Any]] = []
+
+            # Handle tool calls in assistant messages
+            if msg.tool_calls:
+                for tc in msg.tool_calls:
+                    parts.append({
+                        "functionCall": {
+                            "name": tc.name,
+                            "args": tc.params,
+                        }
+                    })
+            elif isinstance(msg.content, str):
+                if msg.content:
+                    parts.append({"text": msg.content})
+            elif isinstance(msg.content, list):
+                # Handle content blocks (including tool results)
+                for block in msg.content:
+                    # Handle ToolResultBlock Pydantic model
+                    if isinstance(block, ToolResultBlock):
+                        # For Gemini, tool results need functionResponse format
+                        tool_name = block.tool_name or "unknown"
+                        parts.append({
+                            "functionResponse": {
+                                "name": tool_name,
+                                "response": {
+                                    "result": block.content,
+                                    "is_error": block.is_error,
+                                }
+                            }
+                        })
+                    elif isinstance(block, dict):
+                        # Handle tool_result dict (fallback)
+                        if block.get("type") == "tool_result":
+                            tool_name = block.get("tool_name", "unknown")
+                            parts.append({
+                                "functionResponse": {
+                                    "name": tool_name,
+                                    "response": {
+                                        "result": block.get("content", ""),
+                                        "is_error": block.get("is_error", False),
+                                    }
+                                }
+                            })
+                        elif "text" in block:
+                            parts.append({"text": block["text"]})
+                    elif hasattr(block, "text"):
+                        parts.append({"text": block.text})
+
+            if parts:
+                contents.append({"role": role, "parts": parts})
+
+        return contents
+
+    def _convert_gemini_response_to_message(
+        self, response: dict[str, Any]
+    ) -> Message:
+        """Convert Gemini API response to voxagent Message.
+
+        Args:
+            response: Gemini API response dict.
+
+        Returns:
+            A voxagent Message.
+        """
+        candidates = response.get("candidates", [])
+        if not candidates:
+            return Message(role="assistant", content="")
+
+        content_obj = candidates[0].get("content", {})
+        parts = content_obj.get("parts", [])
+
+        text_content = ""
+        tool_calls: list[ToolCall] = []
+
+        for part in parts:
+            if "text" in part:
+                text_content += part["text"]
+            elif "functionCall" in part:
+                fc = part["functionCall"]
+                tool_calls.append(ToolCall(
+                    id=str(uuid.uuid4()),
+                    name=fc.get("name", ""),
+                    params=fc.get("args", {}),
+                ))
+
+        return Message(
+            role="assistant",
+            content=text_content,
+            tool_calls=tool_calls if tool_calls else None,
+        )
+
+    def _sanitize_schema_for_gemini(self, schema: dict[str, Any]) -> dict[str, Any]:
+        """Sanitize a JSON schema for Gemini API compatibility.
+
+        Google's Gemini API has specific requirements:
+        - Enum values must be strings
+        - Enum is only allowed for STRING type properties
+
+        This method recursively sanitizes the schema.
+
+        Args:
+            schema: A JSON schema dict.
+
+        Returns:
+            Sanitized schema compatible with Gemini API.
+        """
+        if not isinstance(schema, dict):
+            return schema
+
+        result = {}
+        for key, value in schema.items():
+            if key == "enum" and isinstance(value, list):
+                # Convert all enum values to strings
+                result[key] = [str(v) for v in value]
+            elif isinstance(value, dict):
+                result[key] = self._sanitize_schema_for_gemini(value)
+            elif isinstance(value, list):
+                result[key] = [
+                    self._sanitize_schema_for_gemini(item) if isinstance(item, dict) else item
+                    for item in value
+                ]
+            else:
+                result[key] = value
+
+        # If this schema has an enum, ensure type is "string" (Gemini requirement)
+        if "enum" in result:
+            result["type"] = "string"
+
+        return result
+
+    def _convert_tools_to_gemini(
+        self, tools: list[dict[str, Any]]
+    ) -> list[dict[str, Any]]:
+        """Convert tool definitions to Gemini format.
+
+        Args:
+            tools: List of tool definitions in OpenAI format.
+                Each tool has structure: {"type": "function", "function": {...}}
+
+        Returns:
+            Gemini-formatted tool declarations.
+        """
+        function_declarations = []
+        for tool in tools:
+            # Handle OpenAI format: {"type": "function", "function": {...}}
+            if tool.get("type") == "function" and "function" in tool:
+                func_info = tool["function"]
+                func_decl: dict[str, Any] = {
+                    "name": func_info.get("name", ""),
+                    "description": func_info.get("description", ""),
+                }
+                if "parameters" in func_info:
+                    # Sanitize parameters schema for Gemini compatibility
+                    func_decl["parameters"] = self._sanitize_schema_for_gemini(
+                        func_info["parameters"]
+                    )
+            else:
+                # Fallback: assume flat structure
+                func_decl = {
+                    "name": tool.get("name", ""),
+                    "description": tool.get("description", ""),
+                }
+                if "parameters" in tool:
+                    func_decl["parameters"] = self._sanitize_schema_for_gemini(
+                        tool["parameters"]
+                    )
+            function_declarations.append(func_decl)
+
+        return [{"functionDeclarations": function_declarations}]
+
+    def _build_request_body(
+        self,
+        messages: list[Message],
+        system: str | None = None,
+        tools: list[Any] | None = None,
+    ) -> dict[str, Any]:
+        """Build the request body for Gemini API.
+
+        Args:
+            messages: The conversation messages.
+            system: Optional system prompt.
+            tools: Optional list of tool definitions.
+
+        Returns:
+            Request body dict.
+        """
+        body: dict[str, Any] = {
+            "contents": self._convert_messages_to_gemini(messages),
+        }
+
+        if system:
+            body["system_instruction"] = {"parts": [{"text": system}]}
+
+        if tools:
+            body["tools"] = self._convert_tools_to_gemini(tools)
+
+        return body
+
+    async def _make_streaming_request(
+        self,
+        messages: list[Message],
+        system: str | None = None,
+        tools: list[Any] | None = None,
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Make a streaming request to the Gemini API.
+
+        Args:
+            messages: The conversation messages.
+            system: Optional system prompt.
+            tools: Optional list of tool definitions.
+
+        Yields:
+            Parsed JSON response chunks.
+        """
+        body = self._build_request_body(messages, system=system, tools=tools)
+        url = f"{self._get_stream_endpoint()}?key={self.api_key}&alt=sse"
+
+        async with httpx.AsyncClient() as client:
+            async with client.stream(
+                "POST",
+                url,
+                json=body,
+                headers={"Content-Type": "application/json"},
+                timeout=120.0,
+            ) as response:
+                if response.status_code >= 400:
+                    # Read the full error response for better error messages
+                    error_body = await response.aread()
+                    try:
+                        error_json = json.loads(error_body)
+                        error_msg = error_json.get("error", {}).get("message", error_body.decode())
+                    except (json.JSONDecodeError, UnicodeDecodeError):
+                        error_msg = error_body.decode() if error_body else "Unknown error"
+                    raise httpx.HTTPStatusError(
+                        f"Google API error: {error_msg}",
+                        request=response.request,
+                        response=response,
+                    )
+                async for line in response.aiter_lines():
+                    line = line.strip()
+                    if not line:
+                        continue
+                    if line.startswith("data:"):
+                        data_str = line[5:].strip()
+                        if data_str:
+                            try:
+                                yield json.loads(data_str)
+                            except json.JSONDecodeError:
+                                continue
+
+    async def stream(
+        self,
+        messages: list[Message],
+        system: str | None = None,
+        tools: list[Any] | None = None,
+        abort_signal: AbortSignal | None = None,
+    ) -> AsyncIterator[StreamChunk]:
+        """Stream a response from the Gemini API.
+
+        Args:
+            messages: The conversation messages.
+            system: Optional system prompt.
+            tools: Optional list of tool definitions.
+            abort_signal: Optional signal to abort the stream.
+
+        Yields:
+            StreamChunk objects containing response data.
+        """
+        try:
+            async for chunk in self._make_streaming_request(
+                messages, system=system, tools=tools
+            ):
+                if abort_signal and abort_signal.aborted:
+                    break
+
+                candidates = chunk.get("candidates", [])
+                if not candidates:
+                    continue
+
+                content = candidates[0].get("content", {})
+                parts = content.get("parts", [])
+
+                for part in parts:
+                    if "text" in part:
+                        yield TextDeltaChunk(delta=part["text"])
+                    elif "functionCall" in part:
+                        fc = part["functionCall"]
+                        yield ToolUseChunk(
+                            tool_call=ToolCall(
+                                id=str(uuid.uuid4()),
+                                name=fc.get("name", ""),
+                                params=fc.get("args", {}),
+                            )
+                        )
+
+            yield MessageEndChunk()
+
+        except Exception as e:
+            yield ErrorChunk(error=str(e))
+
+    async def _make_request(
+        self,
+        messages: list[Message],
+        system: str | None = None,
+        tools: list[Any] | None = None,
+    ) -> dict[str, Any]:
+        """Make a non-streaming request to the Gemini API.
+
+        Args:
+            messages: The conversation messages.
+            system: Optional system prompt.
+            tools: Optional list of tool definitions.
+
+        Returns:
+            The response JSON dict.
+        """
+        body = self._build_request_body(messages, system=system, tools=tools)
+        url = f"{self._get_complete_endpoint()}?key={self.api_key}"
+
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                url,
+                json=body,
+                headers={"Content-Type": "application/json"},
+                timeout=120.0,
+            )
+            response.raise_for_status()
+            return response.json()
+
+    async def complete(
+        self,
+        messages: list[Message],
+        system: str | None = None,
+        tools: list[Any] | None = None,
+    ) -> Message:
+        """Get a complete response from the Gemini API.
+
+        Args:
+            messages: The conversation messages.
+            system: Optional system prompt.
+            tools: Optional list of tool definitions.
+
+        Returns:
+            The assistant's response message.
+        """
+        response = await self._make_request(messages, system=system, tools=tools)
+        return self._convert_gemini_response_to_message(response)
+
+    def count_tokens(
+        self,
+        messages: list[Message],
+        system: str | None = None,
+    ) -> int:
+        """Count tokens in the messages.
+
+        Uses a simple estimation based on character count.
+        For more accurate counting, use the Gemini token counting API.
+
+        Args:
+            messages: The conversation messages.
+            system: Optional system prompt.
+
+        Returns:
+            Estimated token count.
+        """
+        # Approximate token counting (roughly 4 chars per token for English)
+        char_count = 0
+
+        if system:
+            char_count += len(system)
+
+        for msg in messages:
+            if isinstance(msg.content, str):
+                char_count += len(msg.content)
+            else:
+                for block in msg.content:
+                    if hasattr(block, "text"):
+                        char_count += len(block.text)
+                    elif hasattr(block, "content"):
+                        char_count += len(block.content)
+
+            # Add overhead for role and structure
+            char_count += 10
+
+            if msg.tool_calls:
+                for tc in msg.tool_calls:
+                    char_count += len(tc.name) + len(json.dumps(tc.params))
+
+        # Rough estimate: 4 characters per token
+        return max(1, char_count // 4)
+
+
+__all__ = ["GoogleProvider"]
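
For orientation, here is a minimal driver sketch for the GoogleProvider added above. This is editor-supplied illustration, not part of the package: the provider, chunk, and Message APIs are taken from this diff, while the get_weather tool, the prompts, and the environment setup (a GOOGLE_API_KEY or GEMINI_API_KEY variable) are hypothetical.

import asyncio

from voxagent.providers.base import ErrorChunk, TextDeltaChunk, ToolUseChunk
from voxagent.providers.google import GoogleProvider
from voxagent.types import Message


async def main() -> None:
    # With no api_key argument, the key is read from GOOGLE_API_KEY
    # (or GEMINI_API_KEY) via the api_key property shown in the diff.
    provider = GoogleProvider(model="gemini-2.0-flash")

    # A tool in OpenAI format; _convert_tools_to_gemini() rewrites it into a
    # functionDeclarations entry, and _sanitize_schema_for_gemini() forces the
    # enum below to string type, per the Gemini requirement noted above.
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "description": "Look up current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "units": {"type": "string", "enum": ["metric", "imperial"]},
                },
                "required": ["city"],
            },
        },
    }]

    messages = [Message(role="user", content="What's the weather in Lisbon?")]

    async for chunk in provider.stream(messages, system="Be concise.", tools=tools):
        if isinstance(chunk, TextDeltaChunk):
            print(chunk.delta, end="", flush=True)
        elif isinstance(chunk, ToolUseChunk):
            print(f"\n[tool call] {chunk.tool_call.name}({chunk.tool_call.params})")
        elif isinstance(chunk, ErrorChunk):
            print(f"\n[error] {chunk.error}")


asyncio.run(main())

Note that stream() swallows exceptions into ErrorChunk rather than raising, so a caller must check for ErrorChunk explicitly, as above.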

voxagent/providers/groq.py
@@ -0,0 +1,96 @@
+"""Groq provider implementation.
+
+This module implements the Groq provider for chat completions,
+using Groq's OpenAI-compatible API format.
+"""
+
+from typing import Any
+
+from voxagent.providers.openai import OpenAIProvider
+
+
+class GroqProvider(OpenAIProvider):
+    """Groq chat completions provider.
+
+    Supports Llama, Mixtral, and Gemma models via Groq's
+    OpenAI-compatible API with streaming and tool calling.
+    """
+
+    ENV_KEY = "GROQ_API_KEY"
+    DEFAULT_BASE_URL = "https://api.groq.com/openai/v1"
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        base_url: str | None = None,
+        model: str = "llama-3.3-70b-versatile",
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the Groq provider.
+
+        Args:
+            api_key: Groq API key. Falls back to GROQ_API_KEY env var.
+            base_url: Custom base URL for API requests.
+            model: Model name to use. Defaults to "llama-3.3-70b-versatile".
+            **kwargs: Additional provider-specific arguments.
+        """
+        super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
+
+    @property
+    def name(self) -> str:
+        """Get the provider name."""
+        return "groq"
+
+    @property
+    def models(self) -> list[str]:
+        """Get the list of supported model names."""
+        return [
+            "llama-3.3-70b-versatile",
+            "llama-3.1-8b-instant",
+            "mixtral-8x7b-32768",
+            "gemma2-9b-it",
+        ]
+
+    @property
+    def context_limit(self) -> int:
+        """Get the maximum context length in tokens."""
+        return 131072  # 128K
+
+    def _get_api_url(self) -> str:
+        """Get the API URL for chat completions."""
+        base = self._base_url or self.DEFAULT_BASE_URL
+        return f"{base}/chat/completions"
+
+    def count_tokens(
+        self,
+        messages: list,
+        system: str | None = None,
+    ) -> int:
+        """Count tokens in the messages.
+
+        Uses character-based estimation since tiktoken may not have
+        encodings for Groq-specific models.
+
+        Args:
+            messages: The conversation messages.
+            system: Optional system prompt.
+
+        Returns:
+            The estimated token count.
+        """
+        # Fall back to character-based estimation (roughly 4 chars per token)
+        total_chars = 0
+        if system:
+            total_chars += len(system)
+        for msg in messages:
+            if isinstance(msg.content, str):
+                total_chars += len(msg.content)
+            else:
+                for block in msg.content:
+                    if hasattr(block, "text"):
+                        total_chars += len(block.text)
+        return max(1, total_chars // 4)
+
+
+__all__ = ["GroqProvider"]
+
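
A corresponding sketch for the GroqProvider, again hypothetical: it exercises only the constructor, properties, and count_tokens() shown in this hunk. Request behavior is inherited from OpenAIProvider, which is outside this excerpt, so it is only referenced in a comment.

from voxagent.providers.groq import GroqProvider
from voxagent.types import Message

# With no api_key argument, the key presumably comes from the GROQ_API_KEY
# environment variable (ENV_KEY), via logic inherited from OpenAIProvider.
provider = GroqProvider(model="llama-3.1-8b-instant")

print(provider.name)           # "groq"
print(provider.context_limit)  # 131072
print(provider.models)         # the four models listed in the diff above

messages = [Message(role="user", content="Summarize the CAP theorem in one line.")]
# Character-based estimate: roughly total characters // 4, never less than 1.
print(provider.count_tokens(messages, system="Be terse."))

# Streaming and completion come from OpenAIProvider (not shown in this diff);
# assuming its interface mirrors GoogleProvider's above, usage would look like:
#     async for chunk in provider.stream(messages): ...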