daita_agents-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- daita/__init__.py +216 -0
- daita/agents/__init__.py +33 -0
- daita/agents/base.py +743 -0
- daita/agents/substrate.py +1141 -0
- daita/cli/__init__.py +145 -0
- daita/cli/__main__.py +7 -0
- daita/cli/ascii_art.py +44 -0
- daita/cli/core/__init__.py +0 -0
- daita/cli/core/create.py +254 -0
- daita/cli/core/deploy.py +473 -0
- daita/cli/core/deployments.py +309 -0
- daita/cli/core/import_detector.py +219 -0
- daita/cli/core/init.py +481 -0
- daita/cli/core/logs.py +239 -0
- daita/cli/core/managed_deploy.py +709 -0
- daita/cli/core/run.py +648 -0
- daita/cli/core/status.py +421 -0
- daita/cli/core/test.py +239 -0
- daita/cli/core/webhooks.py +172 -0
- daita/cli/main.py +588 -0
- daita/cli/utils.py +541 -0
- daita/config/__init__.py +62 -0
- daita/config/base.py +159 -0
- daita/config/settings.py +184 -0
- daita/core/__init__.py +262 -0
- daita/core/decision_tracing.py +701 -0
- daita/core/exceptions.py +480 -0
- daita/core/focus.py +251 -0
- daita/core/interfaces.py +76 -0
- daita/core/plugin_tracing.py +550 -0
- daita/core/relay.py +779 -0
- daita/core/reliability.py +381 -0
- daita/core/scaling.py +459 -0
- daita/core/tools.py +554 -0
- daita/core/tracing.py +770 -0
- daita/core/workflow.py +1144 -0
- daita/display/__init__.py +1 -0
- daita/display/console.py +160 -0
- daita/execution/__init__.py +58 -0
- daita/execution/client.py +856 -0
- daita/execution/exceptions.py +92 -0
- daita/execution/models.py +317 -0
- daita/llm/__init__.py +60 -0
- daita/llm/anthropic.py +291 -0
- daita/llm/base.py +530 -0
- daita/llm/factory.py +101 -0
- daita/llm/gemini.py +355 -0
- daita/llm/grok.py +219 -0
- daita/llm/mock.py +172 -0
- daita/llm/openai.py +220 -0
- daita/plugins/__init__.py +141 -0
- daita/plugins/base.py +37 -0
- daita/plugins/base_db.py +167 -0
- daita/plugins/elasticsearch.py +849 -0
- daita/plugins/mcp.py +481 -0
- daita/plugins/mongodb.py +520 -0
- daita/plugins/mysql.py +362 -0
- daita/plugins/postgresql.py +342 -0
- daita/plugins/redis_messaging.py +500 -0
- daita/plugins/rest.py +537 -0
- daita/plugins/s3.py +770 -0
- daita/plugins/slack.py +729 -0
- daita/utils/__init__.py +18 -0
- daita_agents-0.2.0.dist-info/METADATA +409 -0
- daita_agents-0.2.0.dist-info/RECORD +69 -0
- daita_agents-0.2.0.dist-info/WHEEL +5 -0
- daita_agents-0.2.0.dist-info/entry_points.txt +2 -0
- daita_agents-0.2.0.dist-info/licenses/LICENSE +56 -0
- daita_agents-0.2.0.dist-info/top_level.txt +1 -0
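For orientation, the paths above map directly onto import paths once the wheel is installed. A minimal sketch, assuming the distribution installs from a public registry under the name taken from the wheel filename; nothing below is part of the package itself:

# pip install daita-agents            (distribution name, from the wheel filename)
# The single top-level import package is `daita` (see top_level.txt), so for example:
from daita.llm.gemini import GeminiProvider   # daita/llm/gemini.py, shown below
from daita.llm.grok import GrokProvider       # daita/llm/grok.py, shown below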
daita/llm/gemini.py
ADDED
@@ -0,0 +1,355 @@
+"""
+Google Gemini LLM provider implementation with integrated tracing.
+"""
+import os
+import logging
+import asyncio
+from typing import Dict, Any, Optional, List
+
+from ..core.exceptions import LLMError
+from .base import BaseLLMProvider
+
+logger = logging.getLogger(__name__)
+
+class GeminiProvider(BaseLLMProvider):
+    """Google Gemini LLM provider implementation with automatic call tracing."""
+
+    def __init__(
+        self,
+        model: str = "gemini-2.5-flash",
+        api_key: Optional[str] = None,
+        **kwargs
+    ):
+        """
+        Initialize Gemini provider.
+
+        Args:
+            model: Gemini model name (e.g., "gemini-2.5-flash", "gemini-2.5-pro", "gemini-2.0-flash")
+            api_key: Google AI API key
+            **kwargs: Additional Gemini-specific parameters
+        """
+        # Get API key from parameter or environment
+        api_key = api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
+
+        super().__init__(model=model, api_key=api_key, **kwargs)
+
+        # Gemini-specific default parameters
+        self.default_params.update({
+            'timeout': kwargs.get('timeout', 60),
+            'safety_settings': kwargs.get('safety_settings', None),
+            'generation_config': kwargs.get('generation_config', None)
+        })
+
+        # Lazy-load Gemini client
+        self._client = None
+
+    @property
+    def client(self):
+        """Lazy-load Google Generative AI client."""
+        if self._client is None:
+            try:
+                import google.generativeai as genai
+                self._validate_api_key()
+
+                # Configure the API key
+                genai.configure(api_key=self.api_key)
+
+                # Create the generative model
+                self._client = genai.GenerativeModel(self.model)
+                logger.debug("Gemini client initialized")
+            except ImportError:
+                raise LLMError(
+                    "Google Generative AI package not installed. Install with: pip install google-generativeai"
+                )
+        return self._client
+
+    async def _generate_impl(self, prompt: str, **kwargs) -> str:
+        """
+        Provider-specific implementation of text generation for Gemini.
+
+        This method contains the actual Gemini API call logic and is automatically
+        wrapped with tracing by the base class generate() method.
+
+        Args:
+            prompt: Input prompt
+            **kwargs: Optional parameters
+
+        Returns:
+            Generated text response
+        """
+        try:
+            # Merge parameters
+            params = self._merge_params(kwargs)
+
+            # Prepare generation config
+            generation_config = params.get('generation_config', {})
+            if not generation_config:
+                # Gemini requires max_output_tokens to be set explicitly
+                max_tokens = params.get('max_tokens')
+                if max_tokens is None:
+                    max_tokens = 2048  # Reasonable default for Gemini
+
+                generation_config = {
+                    'max_output_tokens': max_tokens,
+                    'temperature': params.get('temperature'),
+                    'top_p': params.get('top_p')
+                }
+
+            # Make API call (Gemini's generate_content can be sync or async)
+            # For consistency with other providers, we'll run in executor if needed
+            if asyncio.iscoroutinefunction(self.client.generate_content):
+                response = await self.client.generate_content(
+                    prompt,
+                    generation_config=generation_config,
+                    safety_settings=params.get('safety_settings')
+                )
+            else:
+                # Run synchronous method in executor
+                loop = asyncio.get_event_loop()
+                response = await loop.run_in_executor(
+                    None,
+                    lambda: self.client.generate_content(
+                        prompt,
+                        generation_config=generation_config,
+                        safety_settings=params.get('safety_settings')
+                    )
+                )
+
+            # Store usage info if available (Gemini's usage tracking varies)
+            if hasattr(response, 'usage_metadata'):
+                self._last_usage = response.usage_metadata
+
+            # Handle blocked or empty responses
+            if not response.parts:
+                finish_reason = response.candidates[0].finish_reason if response.candidates else None
+                if finish_reason == 2:  # MAX_TOKENS
+                    logger.warning("Gemini response hit max_tokens limit, returning partial response")
+                    return "[Response truncated due to token limit]"
+                elif finish_reason == 3:  # SAFETY
+                    logger.warning("Gemini response blocked by safety filters")
+                    return "[Response blocked by safety filters]"
+                else:
+                    logger.warning(f"Gemini returned empty response with finish_reason: {finish_reason}")
+                    return "[Empty response from Gemini]"
+
+            return response.text
+
+        except Exception as e:
+            logger.error(f"Gemini generation failed: {str(e)}")
+            raise LLMError(f"Gemini generation failed: {str(e)}")
+
+    def _get_last_token_usage(self) -> Dict[str, int]:
+        """
+        Override base class method to handle Gemini's token format.
+
+        Gemini uses different token field names in usage_metadata.
+        """
+        if self._last_usage:
+            # Gemini format varies, try to extract what we can
+            prompt_tokens = getattr(self._last_usage, 'prompt_token_count', 0)
+            completion_tokens = getattr(self._last_usage, 'candidates_token_count', 0)
+            total_tokens = getattr(self._last_usage, 'total_token_count', prompt_tokens + completion_tokens)
+
+            return {
+                'total_tokens': total_tokens,
+                'prompt_tokens': prompt_tokens,
+                'completion_tokens': completion_tokens
+            }
+
+        # Fallback to base class estimation
+        return super()._get_last_token_usage()
+
+    def _convert_tools_to_format(self, tools: List['AgentTool']) -> List[Dict[str, Any]]:
+        """
+        Convert AgentTool list to Gemini function declaration format.
+
+        Gemini uses a simpler format than OpenAI.
+        """
+        gemini_tools = []
+        for tool in tools:
+            openai_format = tool.to_openai_function()
+
+            # Convert OpenAI format to Gemini format
+            gemini_tools.append({
+                "name": openai_format["function"]["name"],
+                "description": openai_format["function"]["description"],
+                "parameters": openai_format["function"]["parameters"]
+            })
+
+        return gemini_tools
+
+    def _convert_messages_to_gemini(
+        self,
+        messages: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        """
+        Convert universal flat format to Gemini's format.
+
+        Gemini uses "user" and "model" roles (not "assistant").
+        """
+        import google.generativeai.types as genai_types
+
+        gemini_messages = []
+
+        for msg in messages:
+            if msg["role"] == "user":
+                gemini_messages.append({
+                    "role": "user",
+                    "parts": [msg["content"]]
+                })
+            elif msg["role"] == "assistant":
+                if msg.get("tool_calls"):
+                    # Assistant with tool calls
+                    parts = []
+                    for tc in msg["tool_calls"]:
+                        parts.append(genai_types.FunctionCall(
+                            name=tc["name"],
+                            args=tc["arguments"]
+                        ))
+                    gemini_messages.append({
+                        "role": "model",
+                        "parts": parts
+                    })
+                else:
+                    # Regular assistant message
+                    gemini_messages.append({
+                        "role": "model",
+                        "parts": [msg.get("content", "")]
+                    })
+            elif msg["role"] == "tool":
+                # Tool result
+                gemini_messages.append({
+                    "role": "function",
+                    "parts": [genai_types.FunctionResponse(
+                        name=msg.get("name", ""),
+                        response={"result": msg["content"]}
+                    )]
+                })
+
+        return gemini_messages
+
+    async def _generate_with_tools_single(
+        self,
+        messages: List[Dict[str, Any]],
+        tools: List[Dict[str, Any]],
+        **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Gemini-specific tool calling implementation.
+
+        Args:
+            messages: Conversation history in universal flat format
+            tools: Tool specifications in Gemini format
+            **kwargs: Optional parameters
+
+        Returns:
+            {
+                "tool_calls": [...],  # If LLM wants to call tools
+                "content": "...",     # If LLM has final answer
+            }
+        """
+        try:
+            import google.generativeai as genai
+            from google.generativeai.types import FunctionDeclaration, Tool
+
+            # Merge parameters
+            params = self._merge_params(kwargs)
+
+            # Convert tools to Gemini FunctionDeclaration format
+            function_declarations = [
+                FunctionDeclaration(
+                    name=tool["name"],
+                    description=tool["description"],
+                    parameters=tool["parameters"]
+                )
+                for tool in tools
+            ]
+
+            # Create Tool object
+            gemini_tool = Tool(function_declarations=function_declarations)
+
+            # Prepare generation config
+            generation_config = params.get('generation_config', {})
+            if not generation_config:
+                # Gemini requires max_output_tokens to be set explicitly
+                max_tokens = params.get('max_tokens')
+                if max_tokens is None:
+                    max_tokens = 2048  # Reasonable default for Gemini
+
+                generation_config = {
+                    'max_output_tokens': max_tokens,
+                    'temperature': params.get('temperature'),
+                    'top_p': params.get('top_p')
+                }
+
+            # Convert messages to Gemini format
+            gemini_messages = self._convert_messages_to_gemini(messages)
+
+            # Build conversation content
+            # For Gemini, we need to structure the chat differently
+            # The first message should be the system/user prompt
+            if gemini_messages:
+                # Start a chat with history
+                chat = self.client.start_chat(history=gemini_messages[:-1] if len(gemini_messages) > 1 else [])
+                last_message = gemini_messages[-1]["parts"][0] if gemini_messages else ""
+            else:
+                chat = self.client.start_chat()
+                last_message = ""
+
+            # Make API call with tools
+            if asyncio.iscoroutinefunction(chat.send_message):
+                response = await chat.send_message(
+                    last_message,
+                    tools=[gemini_tool],
+                    generation_config=generation_config
+                )
+            else:
+                loop = asyncio.get_event_loop()
+                response = await loop.run_in_executor(
+                    None,
+                    lambda: chat.send_message(
+                        last_message,
+                        tools=[gemini_tool],
+                        generation_config=generation_config
+                    )
+                )
+
+            # Store usage for token tracking
+            if hasattr(response, 'usage_metadata'):
+                self._last_usage = response.usage_metadata
+
+            # Check for function calls in response
+            function_calls = []
+            for part in response.parts:
+                if hasattr(part, 'function_call') and part.function_call:
+                    fc = part.function_call
+                    function_calls.append({
+                        "id": f"call_{len(function_calls)}",
+                        "name": fc.name,
+                        "arguments": dict(fc.args)
+                    })
+
+            if function_calls:
+                # LLM wants to call tools
+                return {
+                    "tool_calls": function_calls
+                }
+            else:
+                # LLM has final answer
+                return {
+                    "content": response.text
+                }
+
+        except Exception as e:
+            logger.error(f"Gemini tool calling failed: {str(e)}")
+            raise LLMError(f"Gemini tool calling failed: {str(e)}")
+
+    @property
+    def info(self) -> Dict[str, Any]:
+        """Get information about the Gemini provider."""
+        base_info = super().info
+        base_info.update({
+            'provider_name': 'Google Gemini',
+            'api_compatible': 'Google AI'
+        })
+        return base_info
daita/llm/grok.py
ADDED
@@ -0,0 +1,219 @@
+"""
+Grok (xAI) LLM provider implementation with integrated tracing.
+"""
+import os
+import logging
+from typing import Dict, Any, Optional, List
+
+from ..core.exceptions import LLMError
+from .base import BaseLLMProvider
+
+logger = logging.getLogger(__name__)
+
+class GrokProvider(BaseLLMProvider):
+    """Grok (xAI) LLM provider implementation with automatic call tracing."""
+
+    def __init__(
+        self,
+        model: str = "grok-3",
+        api_key: Optional[str] = None,
+        **kwargs
+    ):
+        """
+        Initialize Grok provider.
+
+        Args:
+            model: Grok model name (e.g., "grok-3", "grok-vision-beta")
+            api_key: xAI API key
+            **kwargs: Additional Grok-specific parameters
+        """
+        # Get API key from parameter or environment
+        api_key = api_key or os.getenv("XAI_API_KEY") or os.getenv("GROK_API_KEY")
+
+        super().__init__(model=model, api_key=api_key, **kwargs)
+
+        # Grok-specific default parameters
+        self.default_params.update({
+            'stream': kwargs.get('stream', False),
+            'timeout': kwargs.get('timeout', 60)
+        })
+
+        # Base URL for xAI API
+        self.base_url = kwargs.get('base_url', 'https://api.x.ai/v1')
+
+        # Lazy-load OpenAI client (Grok uses OpenAI-compatible API)
+        self._client = None
+
+    @property
+    def client(self):
+        """Lazy-load OpenAI client configured for xAI."""
+        if self._client is None:
+            try:
+                import openai
+                self._validate_api_key()
+                self._client = openai.AsyncOpenAI(
+                    api_key=self.api_key,
+                    base_url=self.base_url
+                )
+                logger.debug("Grok client initialized")
+            except ImportError:
+                raise LLMError(
+                    "OpenAI package not installed. Install with: pip install openai"
+                )
+        return self._client
+
+    def _convert_messages_to_openai(
+        self,
+        messages: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        """
+        Convert universal flat format to OpenAI's nested format.
+
+        Grok uses OpenAI-compatible API, so we need the same conversion.
+        """
+        import json
+
+        openai_messages = []
+        for msg in messages:
+            if msg.get("role") == "assistant" and msg.get("tool_calls"):
+                # Convert flat format to OpenAI's nested format
+                converted_tool_calls = []
+                for tc in msg["tool_calls"]:
+                    converted_tool_calls.append({
+                        "id": tc.get("id", ""),
+                        "type": "function",
+                        "function": {
+                            "name": tc["name"],
+                            "arguments": json.dumps(tc["arguments"]) if isinstance(tc["arguments"], dict) else tc["arguments"]
+                        }
+                    })
+
+                openai_messages.append({
+                    "role": "assistant",
+                    "tool_calls": converted_tool_calls
+                })
+            else:
+                # Pass through other messages unchanged
+                openai_messages.append(msg)
+
+        return openai_messages
+
+    async def _generate_with_tools_single(
+        self,
+        messages: List[Dict[str, Any]],
+        tools: List[Dict[str, Any]],
+        **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Grok tool calling implementation (uses OpenAI-compatible API).
+
+        Args:
+            messages: Conversation history in universal flat format
+            tools: Tool specifications in OpenAI format
+            **kwargs: Optional parameters
+
+        Returns:
+            {
+                "tool_calls": [...],  # If LLM wants to call tools
+                "content": "...",     # If LLM has final answer
+            }
+        """
+        import json
+
+        try:
+            # Merge parameters
+            params = self._merge_params(kwargs)
+
+            # Convert flat format to OpenAI's nested format
+            openai_messages = self._convert_messages_to_openai(messages)
+
+            # Make API call with tools (Grok uses OpenAI-compatible interface)
+            response = await self.client.chat.completions.create(
+                model=self.model,
+                messages=openai_messages,
+                tools=tools,
+                tool_choice="auto",
+                max_tokens=params.get('max_tokens'),
+                temperature=params.get('temperature'),
+                top_p=params.get('top_p'),
+                timeout=params.get('timeout')
+            )
+
+            message = response.choices[0].message
+
+            # Store usage for token tracking
+            if hasattr(response, 'usage'):
+                self._last_usage = response.usage
+
+            if message.tool_calls:
+                # LLM wants to call tools - return in flat format
+                return {
+                    "tool_calls": [
+                        {
+                            "id": tc.id,
+                            "name": tc.function.name,
+                            "arguments": json.loads(tc.function.arguments)
+                        }
+                        for tc in message.tool_calls
+                    ]
+                }
+            else:
+                # LLM has final answer
+                return {
+                    "content": message.content
+                }
+
+        except Exception as e:
+            logger.error(f"Grok tool calling failed: {str(e)}")
+            raise LLMError(f"Grok tool calling failed: {str(e)}")
+
+    async def _generate_impl(self, prompt: str, **kwargs) -> str:
+        """
+        Provider-specific implementation of text generation for Grok.
+
+        This method contains the actual Grok API call logic and is automatically
+        wrapped with tracing by the base class generate() method.
+
+        Args:
+            prompt: Input prompt
+            **kwargs: Optional parameters
+
+        Returns:
+            Generated text response
+        """
+        try:
+            # Merge parameters
+            params = self._merge_params(kwargs)
+
+            # Make API call using OpenAI-compatible interface
+            response = await self.client.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {"role": "user", "content": prompt}
+                ],
+                max_tokens=params.get('max_tokens'),
+                temperature=params.get('temperature'),
+                top_p=params.get('top_p'),
+                stream=params.get('stream'),
+                timeout=params.get('timeout')
+            )
+
+            # Store usage for base class token extraction
+            self._last_usage = response.usage
+
+            return response.choices[0].message.content
+
+        except Exception as e:
+            logger.error(f"Grok generation failed: {str(e)}")
+            raise LLMError(f"Grok generation failed: {str(e)}")
+
+    @property
+    def info(self) -> Dict[str, Any]:
+        """Get information about the Grok provider."""
+        base_info = super().info
+        base_info.update({
+            'base_url': self.base_url,
+            'provider_name': 'Grok (xAI)',
+            'api_compatible': 'OpenAI'
+        })
+        return base_info
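As with the Gemini file, a short illustrative sketch (not part of the package) showing how the xAI-specific pieces above fit together: the XAI_API_KEY / GROK_API_KEY fallback, the OpenAI-compatible client pointed at https://api.x.ai/v1, and the base_url override accepted by the constructor; generate() is again assumed to be the async base-class wrapper around _generate_impl():

import asyncio
from daita.llm.grok import GrokProvider

async def main():
    # Uses XAI_API_KEY or GROK_API_KEY from the environment; base_url defaults to
    # https://api.x.ai/v1 and can be overridden via the base_url keyword argument.
    provider = GrokProvider(model="grok-3")

    # generate() is the traced base-class entry point that calls _generate_impl() above
    text = await provider.generate("Write a haiku about container ships.")
    print(text)

    print(provider.info)  # includes 'base_url' and 'api_compatible': 'OpenAI'

asyncio.run(main())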