daita_agents-0.2.0-py3-none-any.whl

This diff represents the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. daita/__init__.py +216 -0
  2. daita/agents/__init__.py +33 -0
  3. daita/agents/base.py +743 -0
  4. daita/agents/substrate.py +1141 -0
  5. daita/cli/__init__.py +145 -0
  6. daita/cli/__main__.py +7 -0
  7. daita/cli/ascii_art.py +44 -0
  8. daita/cli/core/__init__.py +0 -0
  9. daita/cli/core/create.py +254 -0
  10. daita/cli/core/deploy.py +473 -0
  11. daita/cli/core/deployments.py +309 -0
  12. daita/cli/core/import_detector.py +219 -0
  13. daita/cli/core/init.py +481 -0
  14. daita/cli/core/logs.py +239 -0
  15. daita/cli/core/managed_deploy.py +709 -0
  16. daita/cli/core/run.py +648 -0
  17. daita/cli/core/status.py +421 -0
  18. daita/cli/core/test.py +239 -0
  19. daita/cli/core/webhooks.py +172 -0
  20. daita/cli/main.py +588 -0
  21. daita/cli/utils.py +541 -0
  22. daita/config/__init__.py +62 -0
  23. daita/config/base.py +159 -0
  24. daita/config/settings.py +184 -0
  25. daita/core/__init__.py +262 -0
  26. daita/core/decision_tracing.py +701 -0
  27. daita/core/exceptions.py +480 -0
  28. daita/core/focus.py +251 -0
  29. daita/core/interfaces.py +76 -0
  30. daita/core/plugin_tracing.py +550 -0
  31. daita/core/relay.py +779 -0
  32. daita/core/reliability.py +381 -0
  33. daita/core/scaling.py +459 -0
  34. daita/core/tools.py +554 -0
  35. daita/core/tracing.py +770 -0
  36. daita/core/workflow.py +1144 -0
  37. daita/display/__init__.py +1 -0
  38. daita/display/console.py +160 -0
  39. daita/execution/__init__.py +58 -0
  40. daita/execution/client.py +856 -0
  41. daita/execution/exceptions.py +92 -0
  42. daita/execution/models.py +317 -0
  43. daita/llm/__init__.py +60 -0
  44. daita/llm/anthropic.py +291 -0
  45. daita/llm/base.py +530 -0
  46. daita/llm/factory.py +101 -0
  47. daita/llm/gemini.py +355 -0
  48. daita/llm/grok.py +219 -0
  49. daita/llm/mock.py +172 -0
  50. daita/llm/openai.py +220 -0
  51. daita/plugins/__init__.py +141 -0
  52. daita/plugins/base.py +37 -0
  53. daita/plugins/base_db.py +167 -0
  54. daita/plugins/elasticsearch.py +849 -0
  55. daita/plugins/mcp.py +481 -0
  56. daita/plugins/mongodb.py +520 -0
  57. daita/plugins/mysql.py +362 -0
  58. daita/plugins/postgresql.py +342 -0
  59. daita/plugins/redis_messaging.py +500 -0
  60. daita/plugins/rest.py +537 -0
  61. daita/plugins/s3.py +770 -0
  62. daita/plugins/slack.py +729 -0
  63. daita/utils/__init__.py +18 -0
  64. daita_agents-0.2.0.dist-info/METADATA +409 -0
  65. daita_agents-0.2.0.dist-info/RECORD +69 -0
  66. daita_agents-0.2.0.dist-info/WHEEL +5 -0
  67. daita_agents-0.2.0.dist-info/entry_points.txt +2 -0
  68. daita_agents-0.2.0.dist-info/licenses/LICENSE +56 -0
  69. daita_agents-0.2.0.dist-info/top_level.txt +1 -0
daita/llm/base.py ADDED
@@ -0,0 +1,530 @@
"""
Updated BaseLLMProvider with Unified Tracing Integration

This replaces the old BaseLLMProvider to use the unified tracing system.
All LLM calls are automatically traced without user configuration.

Key Changes:
- Removed old token tracking system completely
- Integrated automatic LLM call tracing
- Simple cost estimation
- Automatic provider/model/token capture
- Zero configuration required
"""

from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from contextlib import asynccontextmanager
import logging
import time
import asyncio

from ..core.tracing import get_trace_manager, TraceType, TraceStatus
from ..core.interfaces import LLMProvider

logger = logging.getLogger(__name__)

class BaseLLMProvider(LLMProvider, ABC):
    """
    Base class for LLM providers with automatic call tracing.

    Every LLM call is automatically traced with:
    - Provider and model details
    - Token usage and costs
    - Latency and performance
    - Input/output content (preview)
    - Error tracking

    Users get full LLM observability without any configuration.
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the LLM provider with automatic tracing.

        Args:
            model: Model identifier
            api_key: API key for authentication
            **kwargs: Additional provider-specific options
        """
        self.model = model
        self.api_key = api_key
        self.config = kwargs

        # Default parameters
        self.default_params = {
            'temperature': kwargs.get('temperature', 0.7),
            'max_tokens': kwargs.get('max_tokens', 1000),
            'top_p': kwargs.get('top_p', 1.0),
        }

        # Agent ID for tracing (set by agent)
        self.agent_id = kwargs.get('agent_id')

        # Get trace manager for automatic tracing
        self.trace_manager = get_trace_manager()

        # Provider name for tracing
        self.provider_name = self.__class__.__name__.replace('Provider', '').lower()

        # Last usage for cost estimation
        self._last_usage = None

        logger.debug(f"Initialized {self.__class__.__name__} with model {model} (automatic tracing enabled)")

    async def generate(self, prompt: str, **kwargs) -> str:
        """
        Generate text from prompt with automatic LLM call tracing.

        Every call is automatically traced with full metadata.

        Args:
            prompt: Input prompt
            **kwargs: Optional parameters to override defaults

        Returns:
            Generated text response
        """
        # Merge parameters
        params = self._merge_params(kwargs)

        # Automatically trace the LLM call
        async with self.trace_manager.span(
            operation_name=f"llm_{self.provider_name}_{self.model}",
            trace_type=TraceType.LLM_CALL,
            agent_id=self.agent_id,
            input_data=prompt,
            llm_provider=self.provider_name,
            llm_model=self.model,
            temperature=str(params.get('temperature', 0.7)),
            max_tokens=str(params.get('max_tokens', 1000))
        ) as span_id:

            try:
                start_time = time.time()

                # Execute the actual LLM call
                response = await self._generate_impl(prompt, **params)

                end_time = time.time()
                duration_ms = (end_time - start_time) * 1000

                # Get token usage from the call
                token_usage = self._get_last_token_usage()

                # Record LLM call details in the trace
                self.trace_manager.record_llm_call(
                    span_id=span_id,
                    provider=self.provider_name,
                    model=self.model,
                    tokens=token_usage
                )

                # Add cost estimation if available
                cost = self._estimate_cost(token_usage)
                if cost and span_id in self.trace_manager._active_spans:
                    span = self.trace_manager._active_spans[span_id]
                    span.metadata['estimated_cost_usd'] = cost
                    span.metadata['duration_ms'] = duration_ms

                logger.debug(f"LLM call completed: {token_usage.get('total_tokens', 0)} tokens in {duration_ms:.1f}ms")
                return response

            except Exception as e:
                # LLM call failed - error automatically recorded by span context
                logger.warning(f"LLM call failed: {str(e)}")
                raise

    @abstractmethod
    async def _generate_impl(self, prompt: str, **kwargs) -> str:
        """
        Provider-specific implementation of text generation.

        This method must be implemented by each provider (OpenAI, Anthropic, etc.)
        and contains the actual LLM API call logic.

        Args:
            prompt: Input prompt
            **kwargs: Optional parameters

        Returns:
            Generated text response
        """
        pass

    def _get_last_token_usage(self) -> Dict[str, int]:
        """
        Get token usage from the last API call.

        This should be overridden by each provider to return actual
        token usage from their last API response.
        """
        if self._last_usage:
            # Try to extract from stored usage object
            if hasattr(self._last_usage, 'total_tokens'):
                # OpenAI format
                return {
                    'total_tokens': getattr(self._last_usage, 'total_tokens', 0),
                    'prompt_tokens': getattr(self._last_usage, 'prompt_tokens', 0),
                    'completion_tokens': getattr(self._last_usage, 'completion_tokens', 0)
                }
            elif hasattr(self._last_usage, 'input_tokens'):
                # Anthropic format
                input_tokens = getattr(self._last_usage, 'input_tokens', 0)
                output_tokens = getattr(self._last_usage, 'output_tokens', 0)
                return {
                    'total_tokens': input_tokens + output_tokens,
                    'prompt_tokens': input_tokens,
                    'completion_tokens': output_tokens
                }

        # Default fallback
        return {
            'total_tokens': 0,
            'prompt_tokens': 0,
            'completion_tokens': 0
        }

    def _estimate_cost(self, token_usage: Dict[str, int]) -> Optional[float]:
        """
        Simple cost estimation for MVP.

        Providers can override with specific pricing.
        """
        total_tokens = token_usage.get('total_tokens', 0)
        if total_tokens == 0:
            return None

        # Generic estimation for MVP - providers should override
        cost_per_1k_tokens = 0.002  # $0.002 per 1K tokens
        return (total_tokens / 1000) * cost_per_1k_tokens

    def _merge_params(self, override_params: Dict[str, Any]) -> Dict[str, Any]:
        """Merge default parameters with overrides."""
        params = self.default_params.copy()
        params.update(override_params)
        return params

    def _validate_api_key(self) -> None:
        """Validate that API key is available."""
        if not self.api_key:
            raise ValueError(f"API key required for {self.__class__.__name__}")

    def set_agent_id(self, agent_id: str):
        """
        Set the agent ID for tracing context.

        This is called automatically by BaseAgent during initialization.
        """
        self.agent_id = agent_id
        logger.debug(f"Set agent ID {agent_id} for {self.provider_name} provider")

    def get_recent_calls(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Get recent LLM calls for this provider's agent from unified tracing."""
        if not self.agent_id:
            return []

        operations = self.trace_manager.get_recent_operations(agent_id=self.agent_id, limit=limit * 2)

        # Filter for LLM calls from this provider
        llm_calls = [
            op for op in operations
            if (op.get('type') == 'llm_call' and
                op.get('metadata', {}).get('llm_provider') == self.provider_name)
        ]

        return llm_calls[:limit]

    def get_token_stats(self) -> Dict[str, Any]:
        """Get token usage statistics from unified tracing."""
        if not self.agent_id:
            return {
                'total_calls': 0,
                'total_tokens': 0,
                'estimated_cost': 0.0
            }

        metrics = self.trace_manager.get_agent_metrics(self.agent_id)

        # Get actual token usage from the most recent call
        last_usage = self._get_last_token_usage()

        return {
            'total_calls': metrics.get('total_operations', 0),  # All operations
            'total_tokens': last_usage.get('total_tokens', 0),  # From last API call
            'prompt_tokens': last_usage.get('prompt_tokens', 0),
            'completion_tokens': last_usage.get('completion_tokens', 0),
            'estimated_cost': self._estimate_cost(last_usage) or 0.0,
            'success_rate': metrics.get('success_rate', 0),
            'avg_latency_ms': metrics.get('avg_latency_ms', 0)
        }

    async def generate_with_tools(
        self,
        prompt: str,
        tools: List['AgentTool'],
        max_iterations: int = 5,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Execute LLM with function calling loop.

        The LLM can autonomously call tools over multiple iterations until
        it has enough information to provide a final answer.

        Args:
            prompt: User instruction/question
            tools: List of tools LLM can call
            max_iterations: Max number of turns (default 5)
            **kwargs: Provider-specific options

        Returns:
            {
                "result": str,        # Final answer from LLM
                "tool_calls": [       # List of all tools called
                    {
                        "tool": str,
                        "arguments": dict,
                        "result": any
                    }
                ],
                "iterations": int     # Number of turns taken
            }
        """
        import json

        conversation = [{"role": "user", "content": prompt}]
        tools_called = []

        for iteration in range(max_iterations):
            # Convert tools to provider-specific format
            tool_specs = self._convert_tools_to_format(tools)

            # Call LLM with tools
            response = await self._generate_with_tools_single(
                messages=conversation,
                tools=tool_specs,
                **kwargs
            )

            # Check if LLM wants to call tools
            if response.get("tool_calls"):
                # Execute each tool call
                for tool_call in response["tool_calls"]:
                    tool_result = await self._execute_tool_call(
                        tool_call=tool_call,
                        tools=tools
                    )

                    tools_called.append({
                        "tool": tool_call["name"],
                        "arguments": tool_call["arguments"],
                        "result": tool_result
                    })

                    # Add to conversation history
                    # Use universal flat format (provider-agnostic)
                    formatted_tool_call = {
                        "id": tool_call.get("id", str(len(tools_called))),
                        "name": tool_call["name"],
                        "arguments": tool_call["arguments"]
                    }

                    conversation.append({
                        "role": "assistant",
                        "tool_calls": [formatted_tool_call]
                    })
                    conversation.append({
                        "role": "tool",
                        "tool_call_id": tool_call.get("id", str(len(tools_called))),
                        "content": json.dumps(tool_result)
                    })
            else:
                # LLM returned final answer
                return {
                    "result": response["content"],
                    "tool_calls": tools_called,
                    "iterations": iteration + 1
                }

        # Exceeded max iterations
        return {
            "result": f"Exceeded maximum iterations ({max_iterations}). Last response: {response.get('content', 'No response')}",
            "tool_calls": tools_called,
            "iterations": max_iterations,
            "error": "max_iterations_exceeded"
        }

    def _convert_tools_to_format(self, tools: List['AgentTool']) -> List[Dict[str, Any]]:
        """
        Convert AgentTool list to provider-specific format.

        Default implementation uses OpenAI format. Providers can override
        to use their own format (e.g., Anthropic).
        """
        return [tool.to_openai_function() for tool in tools]

    async def _execute_tool_call(
        self,
        tool_call: Dict[str, Any],
        tools: List['AgentTool']
    ) -> Any:
        """Execute a single tool call with timeout and error handling."""
        tool_name = tool_call["name"]
        arguments = tool_call["arguments"]

        # Find the tool
        tool = next((t for t in tools if t.name == tool_name), None)
        if not tool:
            return {"error": f"Tool '{tool_name}' not found"}

        # Execute with timeout
        try:
            result = await asyncio.wait_for(
                tool.handler(arguments),
                timeout=tool.timeout_seconds
            )
            return result
        except asyncio.TimeoutError:
            return {"error": f"Tool '{tool_name}' timed out after {tool.timeout_seconds}s"}
        except Exception as e:
            return {"error": f"Tool '{tool_name}' failed: {str(e)}"}

    @abstractmethod
    async def _generate_with_tools_single(
        self,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Single LLM call with tools (provider-specific).

        This method must be implemented by each provider (OpenAI, Anthropic, etc.)
        to handle their specific tool calling format.

        Args:
            messages: Conversation history in OpenAI format
            tools: Tool specifications in provider format
            **kwargs: Optional parameters

        Returns:
            {
                "tool_calls": [...],  # If LLM wants to call tools
                "content": "...",     # If LLM has final answer
            }
        """
        pass

    @property
    def info(self) -> Dict[str, Any]:
        """Get information about this LLM provider."""
        return {
            'provider': self.provider_name,
            'model': self.model,
            'agent_id': self.agent_id,
            'config': {k: v for k, v in self.config.items() if 'key' not in k.lower()},
            'default_params': self.default_params,
            'tracing_enabled': True
        }


# Context manager for batch LLM operations
@asynccontextmanager
async def traced_llm_batch(llm_provider: BaseLLMProvider, batch_name: str = "llm_batch"):
    """
    Context manager for tracing batch LLM operations.

    Usage:
        async with traced_llm_batch(llm, "document_analysis"):
            summary = await llm.generate("Summarize: " + doc1)
            analysis = await llm.generate("Analyze: " + doc2)
    """
    trace_manager = get_trace_manager()

    async with trace_manager.span(
        operation_name=f"llm_batch_{batch_name}",
        trace_type=TraceType.LLM_CALL,
        agent_id=llm_provider.agent_id,
        llm_provider=llm_provider.provider_name,
        batch_operation=batch_name
    ):
        try:
            yield llm_provider
        except Exception as e:
            logger.error(f"LLM batch {batch_name} failed: {e}")
            raise


# Utility functions
def get_llm_traces(agent_id: Optional[str] = None, provider: Optional[str] = None, limit: int = 20) -> List[Dict[str, Any]]:
    """Get recent LLM call traces from unified tracing."""
    trace_manager = get_trace_manager()
    operations = trace_manager.get_recent_operations(agent_id=agent_id, limit=limit * 2)

    # Filter for LLM operations
    llm_ops = [
        op for op in operations
        if op.get('type') == 'llm_call'
    ]

    # Filter by provider if specified
    if provider:
        llm_ops = [
            op for op in llm_ops
            if op.get('metadata', {}).get('llm_provider') == provider
        ]

    return llm_ops[:limit]


def get_llm_stats(agent_id: Optional[str] = None, provider: Optional[str] = None) -> Dict[str, Any]:
    """Get LLM usage statistics from unified tracing."""
    traces = get_llm_traces(agent_id, provider, limit=50)

    if not traces:
        return {"total_calls": 0, "success_rate": 0, "avg_latency_ms": 0}

    total_calls = len(traces)
    successful_calls = len([t for t in traces if t.get('status') == 'success'])

    # Calculate averages
    latencies = [t.get('duration_ms', 0) for t in traces if t.get('duration_ms')]
    avg_latency = sum(latencies) / len(latencies) if latencies else 0

    # Token aggregation (simplified for MVP)
    total_tokens = 0
    for trace in traces:
        metadata = trace.get('metadata', {})
        total_tokens += metadata.get('tokens_total', 0)

    return {
        "total_calls": total_calls,
        "successful_calls": successful_calls,
        "failed_calls": total_calls - successful_calls,
        "success_rate": successful_calls / total_calls if total_calls > 0 else 0,
        "avg_latency_ms": avg_latency,
        "total_tokens": total_tokens,
        "agent_id": agent_id,
        "provider": provider
    }


# Export everything
__all__ = [
    # Base class
    "BaseLLMProvider",

    # Context managers
    "traced_llm_batch",

    # Utility functions
    "get_llm_traces",
    "get_llm_stats"
]
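
For orientation, here is a minimal sketch of what a concrete subclass of BaseLLMProvider looks like. It is illustrative only: EchoProvider, the "echo-1" model name, and "demo_agent" are hypothetical and not part of the package, and it assumes the package's tracing backend initializes with its defaults. The two abstract hooks, _generate_impl and _generate_with_tools_single, are the only methods a provider must supply; generate() and generate_with_tools() in the base class handle tracing, token capture, and the tool-calling loop.

# Illustrative sketch, not package code: EchoProvider is hypothetical.
import asyncio
from daita.llm.base import BaseLLMProvider

class EchoProvider(BaseLLMProvider):
    """Hypothetical provider that echoes prompts back, to show the two hooks."""

    async def _generate_impl(self, prompt: str, **kwargs) -> str:
        # A real provider would call its API here and store the response's
        # usage object on self._last_usage so _get_last_token_usage() can read it.
        return f"echo: {prompt}"

    async def _generate_with_tools_single(self, messages, tools, **kwargs):
        # A real provider would pass `tools` to its API. Returning only
        # "content" (no "tool_calls") ends the loop in generate_with_tools().
        return {"content": f"echo: {messages[-1]['content']}"}

async def main():
    llm = EchoProvider(model="echo-1", agent_id="demo_agent")
    print(await llm.generate("hello"))  # traced automatically by the base class
    print(llm.get_token_stats())        # stats are read back from unified tracing

asyncio.run(main())

Note that provider_name is derived from the class name, so this class would be traced as provider "echo".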
daita/llm/factory.py ADDED
@@ -0,0 +1,101 @@
"""
Factory for creating LLM provider instances.
"""
import logging
from typing import Optional

from ..core.exceptions import LLMError
from .base import BaseLLMProvider
from .openai import OpenAIProvider
from .anthropic import AnthropicProvider
from .grok import GrokProvider
from .gemini import GeminiProvider
from .mock import MockLLMProvider

logger = logging.getLogger(__name__)

# Registry of available providers
PROVIDER_REGISTRY = {
    'openai': OpenAIProvider,
    'anthropic': AnthropicProvider,
    'grok': GrokProvider,
    'gemini': GeminiProvider,
    'mock': MockLLMProvider,
}

def create_llm_provider(
    provider: str,
    model: str,
    api_key: Optional[str] = None,
    agent_id: Optional[str] = None,
    **kwargs
) -> BaseLLMProvider:
    """
    Factory function to create LLM provider instances.

    Args:
        provider: Provider name ('openai', 'anthropic', 'grok', 'gemini', 'mock')
        model: Model identifier
        api_key: API key for authentication
        agent_id: Agent ID for token tracking
        **kwargs: Additional provider-specific parameters

    Returns:
        LLM provider instance

    Raises:
        LLMError: If provider is not supported

    Examples:
        >>> # Create OpenAI provider with token tracking
        >>> llm = create_llm_provider('openai', 'gpt-4', api_key='sk-...', agent_id='my_agent')

        >>> # Create Anthropic provider with token tracking
        >>> llm = create_llm_provider('anthropic', 'claude-3-sonnet-20240229', agent_id='my_agent')

        >>> # Create Grok provider
        >>> llm = create_llm_provider('grok', 'grok-beta', api_key='xai-...', agent_id='my_agent')

        >>> # Create Gemini provider
        >>> llm = create_llm_provider('gemini', 'gemini-1.5-pro', api_key='AIza...', agent_id='my_agent')

        >>> # Create mock provider for testing
        >>> llm = create_llm_provider('mock', 'test-model', agent_id='test_agent')
    """
    provider_name = provider.lower()

    if provider_name not in PROVIDER_REGISTRY:
        available_providers = list(PROVIDER_REGISTRY.keys())
        raise LLMError(
            f"Unsupported LLM provider: {provider}. "
            f"Available providers: {available_providers}"
        )

    provider_class = PROVIDER_REGISTRY[provider_name]

    try:
        # Pass agent_id to provider for token tracking
        return provider_class(
            model=model,
            api_key=api_key,
            agent_id=agent_id,
            **kwargs
        )
    except Exception as e:
        logger.error(f"Failed to create {provider} provider: {str(e)}")
        raise LLMError(f"Failed to create {provider} provider: {str(e)}")

def register_llm_provider(name: str, provider_class) -> None:
    """
    Register a custom LLM provider.

    Args:
        name: Provider name
        provider_class: Provider class that implements LLMProvider interface
    """
    PROVIDER_REGISTRY[name.lower()] = provider_class
    logger.info(f"Registered custom LLM provider: {name}")

def list_available_providers() -> list:
    """Get list of available LLM providers."""
    return list(PROVIDER_REGISTRY.keys())
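
A short usage sketch for the factory, reusing the hypothetical EchoProvider from the sketch after base.py above; the 'mock' call mirrors the docstring examples, while the registration of EchoProvider is illustrative only.

# Illustrative sketch; EchoProvider is the hypothetical class defined earlier.
from daita.llm.factory import (
    create_llm_provider,
    register_llm_provider,
    list_available_providers,
)

# Built-in providers are created by name; unknown names raise LLMError.
mock_llm = create_llm_provider('mock', 'test-model', agent_id='demo_agent')

# Custom providers join the same registry. Names are stored lowercased and
# looked up lowercased, so creation is case-insensitive.
register_llm_provider('echo', EchoProvider)
assert 'echo' in list_available_providers()
echo_llm = create_llm_provider('Echo', 'echo-1', agent_id='demo_agent')

Because PROVIDER_REGISTRY is a module-level dict, registration is process-wide: register custom providers once at startup, before any agent asks the factory for them.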