daita-agents 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- daita/__init__.py +216 -0
- daita/agents/__init__.py +33 -0
- daita/agents/base.py +743 -0
- daita/agents/substrate.py +1141 -0
- daita/cli/__init__.py +145 -0
- daita/cli/__main__.py +7 -0
- daita/cli/ascii_art.py +44 -0
- daita/cli/core/__init__.py +0 -0
- daita/cli/core/create.py +254 -0
- daita/cli/core/deploy.py +473 -0
- daita/cli/core/deployments.py +309 -0
- daita/cli/core/import_detector.py +219 -0
- daita/cli/core/init.py +481 -0
- daita/cli/core/logs.py +239 -0
- daita/cli/core/managed_deploy.py +709 -0
- daita/cli/core/run.py +648 -0
- daita/cli/core/status.py +421 -0
- daita/cli/core/test.py +239 -0
- daita/cli/core/webhooks.py +172 -0
- daita/cli/main.py +588 -0
- daita/cli/utils.py +541 -0
- daita/config/__init__.py +62 -0
- daita/config/base.py +159 -0
- daita/config/settings.py +184 -0
- daita/core/__init__.py +262 -0
- daita/core/decision_tracing.py +701 -0
- daita/core/exceptions.py +480 -0
- daita/core/focus.py +251 -0
- daita/core/interfaces.py +76 -0
- daita/core/plugin_tracing.py +550 -0
- daita/core/relay.py +779 -0
- daita/core/reliability.py +381 -0
- daita/core/scaling.py +459 -0
- daita/core/tools.py +554 -0
- daita/core/tracing.py +770 -0
- daita/core/workflow.py +1144 -0
- daita/display/__init__.py +1 -0
- daita/display/console.py +160 -0
- daita/execution/__init__.py +58 -0
- daita/execution/client.py +856 -0
- daita/execution/exceptions.py +92 -0
- daita/execution/models.py +317 -0
- daita/llm/__init__.py +60 -0
- daita/llm/anthropic.py +291 -0
- daita/llm/base.py +530 -0
- daita/llm/factory.py +101 -0
- daita/llm/gemini.py +355 -0
- daita/llm/grok.py +219 -0
- daita/llm/mock.py +172 -0
- daita/llm/openai.py +220 -0
- daita/plugins/__init__.py +141 -0
- daita/plugins/base.py +37 -0
- daita/plugins/base_db.py +167 -0
- daita/plugins/elasticsearch.py +849 -0
- daita/plugins/mcp.py +481 -0
- daita/plugins/mongodb.py +520 -0
- daita/plugins/mysql.py +362 -0
- daita/plugins/postgresql.py +342 -0
- daita/plugins/redis_messaging.py +500 -0
- daita/plugins/rest.py +537 -0
- daita/plugins/s3.py +770 -0
- daita/plugins/slack.py +729 -0
- daita/utils/__init__.py +18 -0
- daita_agents-0.2.0.dist-info/METADATA +409 -0
- daita_agents-0.2.0.dist-info/RECORD +69 -0
- daita_agents-0.2.0.dist-info/WHEEL +5 -0
- daita_agents-0.2.0.dist-info/entry_points.txt +2 -0
- daita_agents-0.2.0.dist-info/licenses/LICENSE +56 -0
- daita_agents-0.2.0.dist-info/top_level.txt +1 -0
daita/llm/base.py
ADDED
@@ -0,0 +1,530 @@
"""
Updated BaseLLMProvider with Unified Tracing Integration

This replaces the old BaseLLMProvider to use the unified tracing system.
All LLM calls are automatically traced without user configuration.

Key Changes:
- Removed old token tracking system completely
- Integrated automatic LLM call tracing
- Simple cost estimation
- Automatic provider/model/token capture
- Zero configuration required
"""

from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from contextlib import asynccontextmanager
import logging
import time
import asyncio

from ..core.tracing import get_trace_manager, TraceType, TraceStatus
from ..core.interfaces import LLMProvider

logger = logging.getLogger(__name__)

class BaseLLMProvider(LLMProvider, ABC):
    """
    Base class for LLM providers with automatic call tracing.

    Every LLM call is automatically traced with:
    - Provider and model details
    - Token usage and costs
    - Latency and performance
    - Input/output content (preview)
    - Error tracking

    Users get full LLM observability without any configuration.
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the LLM provider with automatic tracing.

        Args:
            model: Model identifier
            api_key: API key for authentication
            **kwargs: Additional provider-specific options
        """
        self.model = model
        self.api_key = api_key
        self.config = kwargs

        # Default parameters
        self.default_params = {
            'temperature': kwargs.get('temperature', 0.7),
            'max_tokens': kwargs.get('max_tokens', 1000),
            'top_p': kwargs.get('top_p', 1.0),
        }

        # Agent ID for tracing (set by agent)
        self.agent_id = kwargs.get('agent_id')

        # Get trace manager for automatic tracing
        self.trace_manager = get_trace_manager()

        # Provider name for tracing
        self.provider_name = self.__class__.__name__.replace('Provider', '').lower()

        # Last usage for cost estimation
        self._last_usage = None

        logger.debug(f"Initialized {self.__class__.__name__} with model {model} (automatic tracing enabled)")

    async def generate(self, prompt: str, **kwargs) -> str:
        """
        Generate text from prompt with automatic LLM call tracing.

        Every call is automatically traced with full metadata.

        Args:
            prompt: Input prompt
            **kwargs: Optional parameters to override defaults

        Returns:
            Generated text response
        """
        # Merge parameters
        params = self._merge_params(kwargs)

        # Automatically trace the LLM call
        async with self.trace_manager.span(
            operation_name=f"llm_{self.provider_name}_{self.model}",
            trace_type=TraceType.LLM_CALL,
            agent_id=self.agent_id,
            input_data=prompt,
            llm_provider=self.provider_name,
            llm_model=self.model,
            temperature=str(params.get('temperature', 0.7)),
            max_tokens=str(params.get('max_tokens', 1000))
        ) as span_id:

            try:
                start_time = time.time()

                # Execute the actual LLM call
                response = await self._generate_impl(prompt, **params)

                end_time = time.time()
                duration_ms = (end_time - start_time) * 1000

                # Get token usage from the call
                token_usage = self._get_last_token_usage()

                # Record LLM call details in the trace
                self.trace_manager.record_llm_call(
                    span_id=span_id,
                    provider=self.provider_name,
                    model=self.model,
                    tokens=token_usage
                )

                # Add cost estimation if available
                cost = self._estimate_cost(token_usage)
                if cost and span_id in self.trace_manager._active_spans:
                    span = self.trace_manager._active_spans[span_id]
                    span.metadata['estimated_cost_usd'] = cost
                    span.metadata['duration_ms'] = duration_ms

                logger.debug(f"LLM call completed: {token_usage.get('total_tokens', 0)} tokens in {duration_ms:.1f}ms")
                return response

            except Exception as e:
                # LLM call failed - error automatically recorded by span context
                logger.warning(f"LLM call failed: {str(e)}")
                raise

    @abstractmethod
    async def _generate_impl(self, prompt: str, **kwargs) -> str:
        """
        Provider-specific implementation of text generation.

        This method must be implemented by each provider (OpenAI, Anthropic, etc.)
        and contains the actual LLM API call logic.

        Args:
            prompt: Input prompt
            **kwargs: Optional parameters

        Returns:
            Generated text response
        """
        pass

    def _get_last_token_usage(self) -> Dict[str, int]:
        """
        Get token usage from the last API call.

        This should be overridden by each provider to return actual
        token usage from their last API response.
        """
        if self._last_usage:
            # Try to extract from stored usage object
            if hasattr(self._last_usage, 'total_tokens'):
                # OpenAI format
                return {
                    'total_tokens': getattr(self._last_usage, 'total_tokens', 0),
                    'prompt_tokens': getattr(self._last_usage, 'prompt_tokens', 0),
                    'completion_tokens': getattr(self._last_usage, 'completion_tokens', 0)
                }
            elif hasattr(self._last_usage, 'input_tokens'):
                # Anthropic format
                input_tokens = getattr(self._last_usage, 'input_tokens', 0)
                output_tokens = getattr(self._last_usage, 'output_tokens', 0)
                return {
                    'total_tokens': input_tokens + output_tokens,
                    'prompt_tokens': input_tokens,
                    'completion_tokens': output_tokens
                }

        # Default fallback
        return {
            'total_tokens': 0,
            'prompt_tokens': 0,
            'completion_tokens': 0
        }

    def _estimate_cost(self, token_usage: Dict[str, int]) -> Optional[float]:
        """
        Simple cost estimation for MVP.

        Providers can override with specific pricing.
        """
        total_tokens = token_usage.get('total_tokens', 0)
        if total_tokens == 0:
            return None

        # Generic estimation for MVP - providers should override
        cost_per_1k_tokens = 0.002  # $0.002 per 1K tokens
        return (total_tokens / 1000) * cost_per_1k_tokens

    def _merge_params(self, override_params: Dict[str, Any]) -> Dict[str, Any]:
        """Merge default parameters with overrides."""
        params = self.default_params.copy()
        params.update(override_params)
        return params

    def _validate_api_key(self) -> None:
        """Validate that API key is available."""
        if not self.api_key:
            raise ValueError(f"API key required for {self.__class__.__name__}")

    def set_agent_id(self, agent_id: str):
        """
        Set the agent ID for tracing context.

        This is called automatically by BaseAgent during initialization.
        """
        self.agent_id = agent_id
        logger.debug(f"Set agent ID {agent_id} for {self.provider_name} provider")

    def get_recent_calls(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Get recent LLM calls for this provider's agent from unified tracing."""
        if not self.agent_id:
            return []

        operations = self.trace_manager.get_recent_operations(agent_id=self.agent_id, limit=limit * 2)

        # Filter for LLM calls from this provider
        llm_calls = [
            op for op in operations
            if (op.get('type') == 'llm_call' and
                op.get('metadata', {}).get('llm_provider') == self.provider_name)
        ]

        return llm_calls[:limit]

    def get_token_stats(self) -> Dict[str, Any]:
        """Get token usage statistics from unified tracing."""
        if not self.agent_id:
            return {
                'total_calls': 0,
                'total_tokens': 0,
                'estimated_cost': 0.0
            }

        metrics = self.trace_manager.get_agent_metrics(self.agent_id)

        # Get actual token usage from the most recent call
        last_usage = self._get_last_token_usage()

        return {
            'total_calls': metrics.get('total_operations', 0),  # All operations
            'total_tokens': last_usage.get('total_tokens', 0),  # From last API call
            'prompt_tokens': last_usage.get('prompt_tokens', 0),
            'completion_tokens': last_usage.get('completion_tokens', 0),
            'estimated_cost': self._estimate_cost(last_usage) or 0.0,
            'success_rate': metrics.get('success_rate', 0),
            'avg_latency_ms': metrics.get('avg_latency_ms', 0)
        }

    async def generate_with_tools(
        self,
        prompt: str,
        tools: List['AgentTool'],
        max_iterations: int = 5,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Execute LLM with function calling loop.

        The LLM can autonomously call tools over multiple iterations until
        it has enough information to provide a final answer.

        Args:
            prompt: User instruction/question
            tools: List of tools LLM can call
            max_iterations: Max number of turns (default 5)
            **kwargs: Provider-specific options

        Returns:
            {
                "result": str,       # Final answer from LLM
                "tool_calls": [      # List of all tools called
                    {
                        "tool": str,
                        "arguments": dict,
                        "result": any
                    }
                ],
                "iterations": int    # Number of turns taken
            }
        """
        import json

        conversation = [{"role": "user", "content": prompt}]
        tools_called = []

        for iteration in range(max_iterations):
            # Convert tools to provider-specific format
            tool_specs = self._convert_tools_to_format(tools)

            # Call LLM with tools
            response = await self._generate_with_tools_single(
                messages=conversation,
                tools=tool_specs,
                **kwargs
            )

            # Check if LLM wants to call tools
            if response.get("tool_calls"):
                # Execute each tool call
                for tool_call in response["tool_calls"]:
                    tool_result = await self._execute_tool_call(
                        tool_call=tool_call,
                        tools=tools
                    )

                    tools_called.append({
                        "tool": tool_call["name"],
                        "arguments": tool_call["arguments"],
                        "result": tool_result
                    })

                    # Add to conversation history
                    # Use universal flat format (provider-agnostic)
                    formatted_tool_call = {
                        "id": tool_call.get("id", str(len(tools_called))),
                        "name": tool_call["name"],
                        "arguments": tool_call["arguments"]
                    }

                    conversation.append({
                        "role": "assistant",
                        "tool_calls": [formatted_tool_call]
                    })
                    conversation.append({
                        "role": "tool",
                        "tool_call_id": tool_call.get("id", str(len(tools_called))),
                        "content": json.dumps(tool_result)
                    })
            else:
                # LLM returned final answer
                return {
                    "result": response["content"],
                    "tool_calls": tools_called,
                    "iterations": iteration + 1
                }

        # Exceeded max iterations
        return {
            "result": f"Exceeded maximum iterations ({max_iterations}). Last response: {response.get('content', 'No response')}",
            "tool_calls": tools_called,
            "iterations": max_iterations,
            "error": "max_iterations_exceeded"
        }

    def _convert_tools_to_format(self, tools: List['AgentTool']) -> List[Dict[str, Any]]:
        """
        Convert AgentTool list to provider-specific format.

        Default implementation uses OpenAI format. Providers can override
        to use their own format (e.g., Anthropic).
        """
        return [tool.to_openai_function() for tool in tools]

    async def _execute_tool_call(
        self,
        tool_call: Dict[str, Any],
        tools: List['AgentTool']
    ) -> Any:
        """Execute a single tool call with timeout and error handling."""
        tool_name = tool_call["name"]
        arguments = tool_call["arguments"]

        # Find the tool
        tool = next((t for t in tools if t.name == tool_name), None)
        if not tool:
            return {"error": f"Tool '{tool_name}' not found"}

        # Execute with timeout
        try:
            result = await asyncio.wait_for(
                tool.handler(arguments),
                timeout=tool.timeout_seconds
            )
            return result
        except asyncio.TimeoutError:
            return {"error": f"Tool '{tool_name}' timed out after {tool.timeout_seconds}s"}
        except Exception as e:
            return {"error": f"Tool '{tool_name}' failed: {str(e)}"}

    @abstractmethod
    async def _generate_with_tools_single(
        self,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Single LLM call with tools (provider-specific).

        This method must be implemented by each provider (OpenAI, Anthropic, etc.)
        to handle their specific tool calling format.

        Args:
            messages: Conversation history in OpenAI format
            tools: Tool specifications in provider format
            **kwargs: Optional parameters

        Returns:
            {
                "tool_calls": [...],  # If LLM wants to call tools
                "content": "...",     # If LLM has final answer
            }
        """
        pass

    @property
    def info(self) -> Dict[str, Any]:
        """Get information about this LLM provider."""
        return {
            'provider': self.provider_name,
            'model': self.model,
            'agent_id': self.agent_id,
            'config': {k: v for k, v in self.config.items() if 'key' not in k.lower()},
            'default_params': self.default_params,
            'tracing_enabled': True
        }


# Context manager for batch LLM operations
@asynccontextmanager
async def traced_llm_batch(llm_provider: BaseLLMProvider, batch_name: str = "llm_batch"):
    """
    Context manager for tracing batch LLM operations.

    Usage:
        async with traced_llm_batch(llm, "document_analysis"):
            summary = await llm.generate("Summarize: " + doc1)
            analysis = await llm.generate("Analyze: " + doc2)
    """
    trace_manager = get_trace_manager()

    async with trace_manager.span(
        operation_name=f"llm_batch_{batch_name}",
        trace_type=TraceType.LLM_CALL,
        agent_id=llm_provider.agent_id,
        llm_provider=llm_provider.provider_name,
        batch_operation=batch_name
    ):
        try:
            yield llm_provider
        except Exception as e:
            logger.error(f"LLM batch {batch_name} failed: {e}")
            raise


# Utility functions
def get_llm_traces(agent_id: Optional[str] = None, provider: Optional[str] = None, limit: int = 20) -> List[Dict[str, Any]]:
    """Get recent LLM call traces from unified tracing."""
    trace_manager = get_trace_manager()
    operations = trace_manager.get_recent_operations(agent_id=agent_id, limit=limit * 2)

    # Filter for LLM operations
    llm_ops = [
        op for op in operations
        if op.get('type') == 'llm_call'
    ]

    # Filter by provider if specified
    if provider:
        llm_ops = [
            op for op in llm_ops
            if op.get('metadata', {}).get('llm_provider') == provider
        ]

    return llm_ops[:limit]


def get_llm_stats(agent_id: Optional[str] = None, provider: Optional[str] = None) -> Dict[str, Any]:
    """Get LLM usage statistics from unified tracing."""
    traces = get_llm_traces(agent_id, provider, limit=50)

    if not traces:
        return {"total_calls": 0, "success_rate": 0, "avg_latency_ms": 0}

    total_calls = len(traces)
    successful_calls = len([t for t in traces if t.get('status') == 'success'])

    # Calculate averages
    latencies = [t.get('duration_ms', 0) for t in traces if t.get('duration_ms')]
    avg_latency = sum(latencies) / len(latencies) if latencies else 0

    # Token aggregation (simplified for MVP)
    total_tokens = 0
    for trace in traces:
        metadata = trace.get('metadata', {})
        total_tokens += metadata.get('tokens_total', 0)

    return {
        "total_calls": total_calls,
        "successful_calls": successful_calls,
        "failed_calls": total_calls - successful_calls,
        "success_rate": successful_calls / total_calls if total_calls > 0 else 0,
        "avg_latency_ms": avg_latency,
        "total_tokens": total_tokens,
        "agent_id": agent_id,
        "provider": provider
    }


# Export everything
__all__ = [
    # Base class
    "BaseLLMProvider",

    # Context managers
    "traced_llm_batch",

    # Utility functions
    "get_llm_traces",
    "get_llm_stats"
]
daita/llm/factory.py
ADDED
@@ -0,0 +1,101 @@
"""
Factory for creating LLM provider instances.
"""
import logging
from typing import Optional

from ..core.exceptions import LLMError
from .base import BaseLLMProvider
from .openai import OpenAIProvider
from .anthropic import AnthropicProvider
from .grok import GrokProvider
from .gemini import GeminiProvider
from .mock import MockLLMProvider

logger = logging.getLogger(__name__)

# Registry of available providers
PROVIDER_REGISTRY = {
    'openai': OpenAIProvider,
    'anthropic': AnthropicProvider,
    'grok': GrokProvider,
    'gemini': GeminiProvider,
    'mock': MockLLMProvider,
}

def create_llm_provider(
    provider: str,
    model: str,
    api_key: Optional[str] = None,
    agent_id: Optional[str] = None,
    **kwargs
) -> BaseLLMProvider:
    """
    Factory function to create LLM provider instances.

    Args:
        provider: Provider name ('openai', 'anthropic', 'grok', 'gemini', 'mock')
        model: Model identifier
        api_key: API key for authentication
        agent_id: Agent ID for token tracking
        **kwargs: Additional provider-specific parameters

    Returns:
        LLM provider instance

    Raises:
        LLMError: If provider is not supported

    Examples:
        >>> # Create OpenAI provider with token tracking
        >>> llm = create_llm_provider('openai', 'gpt-4', api_key='sk-...', agent_id='my_agent')

        >>> # Create Anthropic provider with token tracking
        >>> llm = create_llm_provider('anthropic', 'claude-3-sonnet-20240229', agent_id='my_agent')

        >>> # Create Grok provider
        >>> llm = create_llm_provider('grok', 'grok-beta', api_key='xai-...', agent_id='my_agent')

        >>> # Create Gemini provider
        >>> llm = create_llm_provider('gemini', 'gemini-1.5-pro', api_key='AIza...', agent_id='my_agent')

        >>> # Create mock provider for testing
        >>> llm = create_llm_provider('mock', 'test-model', agent_id='test_agent')
    """
    provider_name = provider.lower()

    if provider_name not in PROVIDER_REGISTRY:
        available_providers = list(PROVIDER_REGISTRY.keys())
        raise LLMError(
            f"Unsupported LLM provider: {provider}. "
            f"Available providers: {available_providers}"
        )

    provider_class = PROVIDER_REGISTRY[provider_name]

    try:
        # Pass agent_id to provider for token tracking
        return provider_class(
            model=model,
            api_key=api_key,
            agent_id=agent_id,
            **kwargs
        )
    except Exception as e:
        logger.error(f"Failed to create {provider} provider: {str(e)}")
        raise LLMError(f"Failed to create {provider} provider: {str(e)}")

def register_llm_provider(name: str, provider_class) -> None:
    """
    Register a custom LLM provider.

    Args:
        name: Provider name
        provider_class: Provider class that implements LLMProvider interface
    """
    PROVIDER_REGISTRY[name.lower()] = provider_class
    logger.info(f"Registered custom LLM provider: {name}")

def list_available_providers() -> list:
    """Get list of available LLM providers."""
    return list(PROVIDER_REGISTRY.keys())