daita_agents-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of daita-agents might be problematic.
- daita/__init__.py +208 -0
- daita/agents/__init__.py +33 -0
- daita/agents/base.py +722 -0
- daita/agents/substrate.py +895 -0
- daita/cli/__init__.py +145 -0
- daita/cli/__main__.py +7 -0
- daita/cli/ascii_art.py +44 -0
- daita/cli/core/__init__.py +0 -0
- daita/cli/core/create.py +254 -0
- daita/cli/core/deploy.py +473 -0
- daita/cli/core/deployments.py +309 -0
- daita/cli/core/import_detector.py +219 -0
- daita/cli/core/init.py +382 -0
- daita/cli/core/logs.py +239 -0
- daita/cli/core/managed_deploy.py +709 -0
- daita/cli/core/run.py +648 -0
- daita/cli/core/status.py +421 -0
- daita/cli/core/test.py +239 -0
- daita/cli/core/webhooks.py +172 -0
- daita/cli/main.py +588 -0
- daita/cli/utils.py +541 -0
- daita/config/__init__.py +62 -0
- daita/config/base.py +159 -0
- daita/config/settings.py +184 -0
- daita/core/__init__.py +262 -0
- daita/core/decision_tracing.py +701 -0
- daita/core/exceptions.py +480 -0
- daita/core/focus.py +251 -0
- daita/core/interfaces.py +76 -0
- daita/core/plugin_tracing.py +550 -0
- daita/core/relay.py +695 -0
- daita/core/reliability.py +381 -0
- daita/core/scaling.py +444 -0
- daita/core/tools.py +402 -0
- daita/core/tracing.py +770 -0
- daita/core/workflow.py +1084 -0
- daita/display/__init__.py +1 -0
- daita/display/console.py +160 -0
- daita/execution/__init__.py +58 -0
- daita/execution/client.py +856 -0
- daita/execution/exceptions.py +92 -0
- daita/execution/models.py +317 -0
- daita/llm/__init__.py +60 -0
- daita/llm/anthropic.py +166 -0
- daita/llm/base.py +373 -0
- daita/llm/factory.py +101 -0
- daita/llm/gemini.py +152 -0
- daita/llm/grok.py +114 -0
- daita/llm/mock.py +135 -0
- daita/llm/openai.py +109 -0
- daita/plugins/__init__.py +141 -0
- daita/plugins/base.py +37 -0
- daita/plugins/base_db.py +167 -0
- daita/plugins/elasticsearch.py +844 -0
- daita/plugins/mcp.py +481 -0
- daita/plugins/mongodb.py +510 -0
- daita/plugins/mysql.py +351 -0
- daita/plugins/postgresql.py +331 -0
- daita/plugins/redis_messaging.py +500 -0
- daita/plugins/rest.py +529 -0
- daita/plugins/s3.py +761 -0
- daita/plugins/slack.py +729 -0
- daita/utils/__init__.py +18 -0
- daita_agents-0.1.0.dist-info/METADATA +350 -0
- daita_agents-0.1.0.dist-info/RECORD +69 -0
- daita_agents-0.1.0.dist-info/WHEEL +5 -0
- daita_agents-0.1.0.dist-info/entry_points.txt +2 -0
- daita_agents-0.1.0.dist-info/licenses/LICENSE +56 -0
- daita_agents-0.1.0.dist-info/top_level.txt +1 -0
daita/llm/base.py
ADDED
@@ -0,0 +1,373 @@

"""
Updated BaseLLMProvider with Unified Tracing Integration

This replaces the old BaseLLMProvider to use the unified tracing system.
All LLM calls are automatically traced without user configuration.

Key Changes:
- Removed old token tracking system completely
- Integrated automatic LLM call tracing
- Simple cost estimation
- Automatic provider/model/token capture
- Zero configuration required
"""

from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from contextlib import asynccontextmanager
import logging
import time
import asyncio

from ..core.tracing import get_trace_manager, TraceType, TraceStatus
from ..core.interfaces import LLMProvider

logger = logging.getLogger(__name__)


class BaseLLMProvider(LLMProvider, ABC):
    """
    Base class for LLM providers with automatic call tracing.

    Every LLM call is automatically traced with:
    - Provider and model details
    - Token usage and costs
    - Latency and performance
    - Input/output content (preview)
    - Error tracking

    Users get full LLM observability without any configuration.
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the LLM provider with automatic tracing.

        Args:
            model: Model identifier
            api_key: API key for authentication
            **kwargs: Additional provider-specific options
        """
        self.model = model
        self.api_key = api_key
        self.config = kwargs

        # Default parameters
        self.default_params = {
            'temperature': kwargs.get('temperature', 0.7),
            'max_tokens': kwargs.get('max_tokens', 1000),
            'top_p': kwargs.get('top_p', 1.0),
        }

        # Agent ID for tracing (set by agent)
        self.agent_id = kwargs.get('agent_id')

        # Get trace manager for automatic tracing
        self.trace_manager = get_trace_manager()

        # Provider name for tracing
        self.provider_name = self.__class__.__name__.replace('Provider', '').lower()

        # Last usage for cost estimation
        self._last_usage = None

        logger.debug(f"Initialized {self.__class__.__name__} with model {model} (automatic tracing enabled)")

    async def generate(self, prompt: str, **kwargs) -> str:
        """
        Generate text from prompt with automatic LLM call tracing.

        Every call is automatically traced with full metadata.

        Args:
            prompt: Input prompt
            **kwargs: Optional parameters to override defaults

        Returns:
            Generated text response
        """
        # Merge parameters
        params = self._merge_params(kwargs)

        # Automatically trace the LLM call
        async with self.trace_manager.span(
            operation_name=f"llm_{self.provider_name}_{self.model}",
            trace_type=TraceType.LLM_CALL,
            agent_id=self.agent_id,
            input_data=prompt,
            llm_provider=self.provider_name,
            llm_model=self.model,
            temperature=str(params.get('temperature', 0.7)),
            max_tokens=str(params.get('max_tokens', 1000))
        ) as span_id:

            try:
                start_time = time.time()

                # Execute the actual LLM call
                response = await self._generate_impl(prompt, **params)

                end_time = time.time()
                duration_ms = (end_time - start_time) * 1000

                # Get token usage from the call
                token_usage = self._get_last_token_usage()

                # Record LLM call details in the trace
                self.trace_manager.record_llm_call(
                    span_id=span_id,
                    provider=self.provider_name,
                    model=self.model,
                    tokens=token_usage
                )

                # Add cost estimation if available
                cost = self._estimate_cost(token_usage)
                if cost and span_id in self.trace_manager._active_spans:
                    span = self.trace_manager._active_spans[span_id]
                    span.metadata['estimated_cost_usd'] = cost
                    span.metadata['duration_ms'] = duration_ms

                logger.debug(f"LLM call completed: {token_usage.get('total_tokens', 0)} tokens in {duration_ms:.1f}ms")
                return response

            except Exception as e:
                # LLM call failed - error automatically recorded by span context
                logger.warning(f"LLM call failed: {str(e)}")
                raise

    @abstractmethod
    async def _generate_impl(self, prompt: str, **kwargs) -> str:
        """
        Provider-specific implementation of text generation.

        This method must be implemented by each provider (OpenAI, Anthropic, etc.)
        and contains the actual LLM API call logic.

        Args:
            prompt: Input prompt
            **kwargs: Optional parameters

        Returns:
            Generated text response
        """
        pass

    def _get_last_token_usage(self) -> Dict[str, int]:
        """
        Get token usage from the last API call.

        This should be overridden by each provider to return actual
        token usage from their last API response.
        """
        if self._last_usage:
            # Try to extract from stored usage object
            if hasattr(self._last_usage, 'total_tokens'):
                # OpenAI format
                return {
                    'total_tokens': getattr(self._last_usage, 'total_tokens', 0),
                    'prompt_tokens': getattr(self._last_usage, 'prompt_tokens', 0),
                    'completion_tokens': getattr(self._last_usage, 'completion_tokens', 0)
                }
            elif hasattr(self._last_usage, 'input_tokens'):
                # Anthropic format
                input_tokens = getattr(self._last_usage, 'input_tokens', 0)
                output_tokens = getattr(self._last_usage, 'output_tokens', 0)
                return {
                    'total_tokens': input_tokens + output_tokens,
                    'prompt_tokens': input_tokens,
                    'completion_tokens': output_tokens
                }

        # Default fallback
        return {
            'total_tokens': 0,
            'prompt_tokens': 0,
            'completion_tokens': 0
        }

    def _estimate_cost(self, token_usage: Dict[str, int]) -> Optional[float]:
        """
        Simple cost estimation for MVP.

        Providers can override with specific pricing.
        """
        total_tokens = token_usage.get('total_tokens', 0)
        if total_tokens == 0:
            return None

        # Generic estimation for MVP - providers should override
        cost_per_1k_tokens = 0.002  # $0.002 per 1K tokens
        return (total_tokens / 1000) * cost_per_1k_tokens

    def _merge_params(self, override_params: Dict[str, Any]) -> Dict[str, Any]:
        """Merge default parameters with overrides."""
        params = self.default_params.copy()
        params.update(override_params)
        return params

    def _validate_api_key(self) -> None:
        """Validate that API key is available."""
        if not self.api_key:
            raise ValueError(f"API key required for {self.__class__.__name__}")

    def set_agent_id(self, agent_id: str):
        """
        Set the agent ID for tracing context.

        This is called automatically by BaseAgent during initialization.
        """
        self.agent_id = agent_id
        logger.debug(f"Set agent ID {agent_id} for {self.provider_name} provider")

    def get_recent_calls(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Get recent LLM calls for this provider's agent from unified tracing."""
        if not self.agent_id:
            return []

        operations = self.trace_manager.get_recent_operations(agent_id=self.agent_id, limit=limit * 2)

        # Filter for LLM calls from this provider
        llm_calls = [
            op for op in operations
            if (op.get('type') == 'llm_call' and
                op.get('metadata', {}).get('llm_provider') == self.provider_name)
        ]

        return llm_calls[:limit]

    def get_token_stats(self) -> Dict[str, Any]:
        """Get token usage statistics from unified tracing."""
        if not self.agent_id:
            return {
                'total_calls': 0,
                'total_tokens': 0,
                'estimated_cost': 0.0
            }

        metrics = self.trace_manager.get_agent_metrics(self.agent_id)

        # Get actual token usage from the most recent call
        last_usage = self._get_last_token_usage()

        return {
            'total_calls': metrics.get('total_operations', 0),  # All operations
            'total_tokens': last_usage.get('total_tokens', 0),  # From last API call
            'prompt_tokens': last_usage.get('prompt_tokens', 0),
            'completion_tokens': last_usage.get('completion_tokens', 0),
            'estimated_cost': self._estimate_cost(last_usage) or 0.0,
            'success_rate': metrics.get('success_rate', 0),
            'avg_latency_ms': metrics.get('avg_latency_ms', 0)
        }

    @property
    def info(self) -> Dict[str, Any]:
        """Get information about this LLM provider."""
        return {
            'provider': self.provider_name,
            'model': self.model,
            'agent_id': self.agent_id,
            'config': {k: v for k, v in self.config.items() if 'key' not in k.lower()},
            'default_params': self.default_params,
            'tracing_enabled': True
        }


# Context manager for batch LLM operations
@asynccontextmanager
async def traced_llm_batch(llm_provider: BaseLLMProvider, batch_name: str = "llm_batch"):
    """
    Context manager for tracing batch LLM operations.

    Usage:
        async with traced_llm_batch(llm, "document_analysis"):
            summary = await llm.generate("Summarize: " + doc1)
            analysis = await llm.generate("Analyze: " + doc2)
    """
    trace_manager = get_trace_manager()

    async with trace_manager.span(
        operation_name=f"llm_batch_{batch_name}",
        trace_type=TraceType.LLM_CALL,
        agent_id=llm_provider.agent_id,
        llm_provider=llm_provider.provider_name,
        batch_operation=batch_name
    ):
        try:
            yield llm_provider
        except Exception as e:
            logger.error(f"LLM batch {batch_name} failed: {e}")
            raise


# Utility functions
def get_llm_traces(agent_id: Optional[str] = None, provider: Optional[str] = None, limit: int = 20) -> List[Dict[str, Any]]:
    """Get recent LLM call traces from unified tracing."""
    trace_manager = get_trace_manager()
    operations = trace_manager.get_recent_operations(agent_id=agent_id, limit=limit * 2)

    # Filter for LLM operations
    llm_ops = [
        op for op in operations
        if op.get('type') == 'llm_call'
    ]

    # Filter by provider if specified
    if provider:
        llm_ops = [
            op for op in llm_ops
            if op.get('metadata', {}).get('llm_provider') == provider
        ]

    return llm_ops[:limit]


def get_llm_stats(agent_id: Optional[str] = None, provider: Optional[str] = None) -> Dict[str, Any]:
    """Get LLM usage statistics from unified tracing."""
    traces = get_llm_traces(agent_id, provider, limit=50)

    if not traces:
        return {"total_calls": 0, "success_rate": 0, "avg_latency_ms": 0}

    total_calls = len(traces)
    successful_calls = len([t for t in traces if t.get('status') == 'success'])

    # Calculate averages
    latencies = [t.get('duration_ms', 0) for t in traces if t.get('duration_ms')]
    avg_latency = sum(latencies) / len(latencies) if latencies else 0

    # Token aggregation (simplified for MVP)
    total_tokens = 0
    for trace in traces:
        metadata = trace.get('metadata', {})
        total_tokens += metadata.get('tokens_total', 0)

    return {
        "total_calls": total_calls,
        "successful_calls": successful_calls,
        "failed_calls": total_calls - successful_calls,
        "success_rate": successful_calls / total_calls if total_calls > 0 else 0,
        "avg_latency_ms": avg_latency,
        "total_tokens": total_tokens,
        "agent_id": agent_id,
        "provider": provider
    }


# Export everything
__all__ = [
    # Base class
    "BaseLLMProvider",

    # Context managers
    "traced_llm_batch",

    # Utility functions
    "get_llm_traces",
    "get_llm_stats"
]
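For orientation, a minimal sketch of the subclass contract this base class implies: a concrete provider only implements _generate_impl() and stores its raw usage object in self._last_usage, while the inherited generate() handles span creation, token recording, and cost estimation. The EchoProvider class and its canned usage numbers below are hypothetical, not part of the package, and the sketch assumes the package is installed and its unified trace manager is available at runtime.

    import asyncio
    from daita.llm.base import BaseLLMProvider, traced_llm_batch

    class EchoProvider(BaseLLMProvider):
        """Hypothetical provider used only to illustrate the subclass contract."""

        async def _generate_impl(self, prompt: str, **kwargs) -> str:
            # A real provider would call its API here and keep the response's
            # usage object so _get_last_token_usage() can read token counts.
            class _Usage:
                total_tokens = 42
                prompt_tokens = 30
                completion_tokens = 12
            self._last_usage = _Usage()
            return f"echo: {prompt[:40]}"

    async def main():
        llm = EchoProvider(model="echo-1", api_key="not-needed", agent_id="demo_agent")
        # Each generate() call is traced automatically; the batch context
        # groups the two calls under a single parent span.
        async with traced_llm_batch(llm, "demo_batch"):
            print(await llm.generate("Summarize: hello world"))
            print(await llm.generate("Analyze: hello world"))
        print(llm.get_token_stats())

    asyncio.run(main())

With the flat default pricing in _estimate_cost(), the 42 tokens above would be estimated at (42 / 1000) * 0.002 ≈ $0.00008; real providers are expected to override that method with their own rates.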
daita/llm/factory.py
ADDED
@@ -0,0 +1,101 @@

"""
Factory for creating LLM provider instances.
"""
import logging
from typing import Optional

from ..core.exceptions import LLMError
from .base import BaseLLMProvider
from .openai import OpenAIProvider
from .anthropic import AnthropicProvider
from .grok import GrokProvider
from .gemini import GeminiProvider
from .mock import MockLLMProvider

logger = logging.getLogger(__name__)

# Registry of available providers
PROVIDER_REGISTRY = {
    'openai': OpenAIProvider,
    'anthropic': AnthropicProvider,
    'grok': GrokProvider,
    'gemini': GeminiProvider,
    'mock': MockLLMProvider,
}


def create_llm_provider(
    provider: str,
    model: str,
    api_key: Optional[str] = None,
    agent_id: Optional[str] = None,
    **kwargs
) -> BaseLLMProvider:
    """
    Factory function to create LLM provider instances.

    Args:
        provider: Provider name ('openai', 'anthropic', 'grok', 'gemini', 'mock')
        model: Model identifier
        api_key: API key for authentication
        agent_id: Agent ID for token tracking
        **kwargs: Additional provider-specific parameters

    Returns:
        LLM provider instance

    Raises:
        LLMError: If provider is not supported

    Examples:
        >>> # Create OpenAI provider with token tracking
        >>> llm = create_llm_provider('openai', 'gpt-4', api_key='sk-...', agent_id='my_agent')

        >>> # Create Anthropic provider with token tracking
        >>> llm = create_llm_provider('anthropic', 'claude-3-sonnet-20240229', agent_id='my_agent')

        >>> # Create Grok provider
        >>> llm = create_llm_provider('grok', 'grok-beta', api_key='xai-...', agent_id='my_agent')

        >>> # Create Gemini provider
        >>> llm = create_llm_provider('gemini', 'gemini-1.5-pro', api_key='AIza...', agent_id='my_agent')

        >>> # Create mock provider for testing
        >>> llm = create_llm_provider('mock', 'test-model', agent_id='test_agent')
    """
    provider_name = provider.lower()

    if provider_name not in PROVIDER_REGISTRY:
        available_providers = list(PROVIDER_REGISTRY.keys())
        raise LLMError(
            f"Unsupported LLM provider: {provider}. "
            f"Available providers: {available_providers}"
        )

    provider_class = PROVIDER_REGISTRY[provider_name]

    try:
        # Pass agent_id to provider for token tracking
        return provider_class(
            model=model,
            api_key=api_key,
            agent_id=agent_id,
            **kwargs
        )
    except Exception as e:
        logger.error(f"Failed to create {provider} provider: {str(e)}")
        raise LLMError(f"Failed to create {provider} provider: {str(e)}")


def register_llm_provider(name: str, provider_class) -> None:
    """
    Register a custom LLM provider.

    Args:
        name: Provider name
        provider_class: Provider class that implements LLMProvider interface
    """
    PROVIDER_REGISTRY[name.lower()] = provider_class
    logger.info(f"Registered custom LLM provider: {name}")


def list_available_providers() -> list:
    """Get list of available LLM providers."""
    return list(PROVIDER_REGISTRY.keys())
daita/llm/gemini.py
ADDED
@@ -0,0 +1,152 @@

"""
Google Gemini LLM provider implementation with integrated tracing.
"""
import os
import logging
import asyncio
from typing import Dict, Any, Optional

from ..core.exceptions import LLMError
from .base import BaseLLMProvider

logger = logging.getLogger(__name__)


class GeminiProvider(BaseLLMProvider):
    """Google Gemini LLM provider implementation with automatic call tracing."""

    def __init__(
        self,
        model: str = "gemini-1.5-flash",
        api_key: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize Gemini provider.

        Args:
            model: Gemini model name (e.g., "gemini-1.5-flash", "gemini-1.5-pro", "gemini-1.0-pro")
            api_key: Google AI API key
            **kwargs: Additional Gemini-specific parameters
        """
        # Get API key from parameter or environment
        api_key = api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")

        super().__init__(model=model, api_key=api_key, **kwargs)

        # Gemini-specific default parameters
        self.default_params.update({
            'timeout': kwargs.get('timeout', 60),
            'safety_settings': kwargs.get('safety_settings', None),
            'generation_config': kwargs.get('generation_config', None)
        })

        # Lazy-load Gemini client
        self._client = None

    @property
    def client(self):
        """Lazy-load Google Generative AI client."""
        if self._client is None:
            try:
                import google.generativeai as genai
                self._validate_api_key()

                # Configure the API key
                genai.configure(api_key=self.api_key)

                # Create the generative model
                self._client = genai.GenerativeModel(self.model)
                logger.debug("Gemini client initialized")
            except ImportError:
                raise LLMError(
                    "Google Generative AI package not installed. Install with: pip install google-generativeai"
                )
        return self._client

    async def _generate_impl(self, prompt: str, **kwargs) -> str:
        """
        Provider-specific implementation of text generation for Gemini.

        This method contains the actual Gemini API call logic and is automatically
        wrapped with tracing by the base class generate() method.

        Args:
            prompt: Input prompt
            **kwargs: Optional parameters

        Returns:
            Generated text response
        """
        try:
            # Merge parameters
            params = self._merge_params(kwargs)

            # Prepare generation config
            generation_config = params.get('generation_config', {})
            if not generation_config:
                generation_config = {
                    'max_output_tokens': params.get('max_tokens'),
                    'temperature': params.get('temperature'),
                    'top_p': params.get('top_p')
                }

            # Make API call (Gemini's generate_content can be sync or async)
            # For consistency with other providers, we'll run in executor if needed
            if asyncio.iscoroutinefunction(self.client.generate_content):
                response = await self.client.generate_content(
                    prompt,
                    generation_config=generation_config,
                    safety_settings=params.get('safety_settings')
                )
            else:
                # Run synchronous method in executor
                loop = asyncio.get_event_loop()
                response = await loop.run_in_executor(
                    None,
                    lambda: self.client.generate_content(
                        prompt,
                        generation_config=generation_config,
                        safety_settings=params.get('safety_settings')
                    )
                )

            # Store usage info if available (Gemini's usage tracking varies)
            if hasattr(response, 'usage_metadata'):
                self._last_usage = response.usage_metadata

            return response.text

        except Exception as e:
            logger.error(f"Gemini generation failed: {str(e)}")
            raise LLMError(f"Gemini generation failed: {str(e)}")

    def _get_last_token_usage(self) -> Dict[str, int]:
        """
        Override base class method to handle Gemini's token format.

        Gemini uses different token field names in usage_metadata.
        """
        if self._last_usage:
            # Gemini format varies, try to extract what we can
            prompt_tokens = getattr(self._last_usage, 'prompt_token_count', 0)
            completion_tokens = getattr(self._last_usage, 'candidates_token_count', 0)
            total_tokens = getattr(self._last_usage, 'total_token_count', prompt_tokens + completion_tokens)

            return {
                'total_tokens': total_tokens,
                'prompt_tokens': prompt_tokens,
                'completion_tokens': completion_tokens
            }

        # Fallback to base class estimation
        return super()._get_last_token_usage()

    @property
    def info(self) -> Dict[str, Any]:
        """Get information about the Gemini provider."""
        base_info = super().info
        base_info.update({
            'provider_name': 'Google Gemini',
            'api_compatible': 'Google AI'
        })
        return base_info