daita_agents-0.1.0-py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Potentially problematic release: this version of daita-agents might be problematic.
Files changed (69)
  1. daita/__init__.py +208 -0
  2. daita/agents/__init__.py +33 -0
  3. daita/agents/base.py +722 -0
  4. daita/agents/substrate.py +895 -0
  5. daita/cli/__init__.py +145 -0
  6. daita/cli/__main__.py +7 -0
  7. daita/cli/ascii_art.py +44 -0
  8. daita/cli/core/__init__.py +0 -0
  9. daita/cli/core/create.py +254 -0
  10. daita/cli/core/deploy.py +473 -0
  11. daita/cli/core/deployments.py +309 -0
  12. daita/cli/core/import_detector.py +219 -0
  13. daita/cli/core/init.py +382 -0
  14. daita/cli/core/logs.py +239 -0
  15. daita/cli/core/managed_deploy.py +709 -0
  16. daita/cli/core/run.py +648 -0
  17. daita/cli/core/status.py +421 -0
  18. daita/cli/core/test.py +239 -0
  19. daita/cli/core/webhooks.py +172 -0
  20. daita/cli/main.py +588 -0
  21. daita/cli/utils.py +541 -0
  22. daita/config/__init__.py +62 -0
  23. daita/config/base.py +159 -0
  24. daita/config/settings.py +184 -0
  25. daita/core/__init__.py +262 -0
  26. daita/core/decision_tracing.py +701 -0
  27. daita/core/exceptions.py +480 -0
  28. daita/core/focus.py +251 -0
  29. daita/core/interfaces.py +76 -0
  30. daita/core/plugin_tracing.py +550 -0
  31. daita/core/relay.py +695 -0
  32. daita/core/reliability.py +381 -0
  33. daita/core/scaling.py +444 -0
  34. daita/core/tools.py +402 -0
  35. daita/core/tracing.py +770 -0
  36. daita/core/workflow.py +1084 -0
  37. daita/display/__init__.py +1 -0
  38. daita/display/console.py +160 -0
  39. daita/execution/__init__.py +58 -0
  40. daita/execution/client.py +856 -0
  41. daita/execution/exceptions.py +92 -0
  42. daita/execution/models.py +317 -0
  43. daita/llm/__init__.py +60 -0
  44. daita/llm/anthropic.py +166 -0
  45. daita/llm/base.py +373 -0
  46. daita/llm/factory.py +101 -0
  47. daita/llm/gemini.py +152 -0
  48. daita/llm/grok.py +114 -0
  49. daita/llm/mock.py +135 -0
  50. daita/llm/openai.py +109 -0
  51. daita/plugins/__init__.py +141 -0
  52. daita/plugins/base.py +37 -0
  53. daita/plugins/base_db.py +167 -0
  54. daita/plugins/elasticsearch.py +844 -0
  55. daita/plugins/mcp.py +481 -0
  56. daita/plugins/mongodb.py +510 -0
  57. daita/plugins/mysql.py +351 -0
  58. daita/plugins/postgresql.py +331 -0
  59. daita/plugins/redis_messaging.py +500 -0
  60. daita/plugins/rest.py +529 -0
  61. daita/plugins/s3.py +761 -0
  62. daita/plugins/slack.py +729 -0
  63. daita/utils/__init__.py +18 -0
  64. daita_agents-0.1.0.dist-info/METADATA +350 -0
  65. daita_agents-0.1.0.dist-info/RECORD +69 -0
  66. daita_agents-0.1.0.dist-info/WHEEL +5 -0
  67. daita_agents-0.1.0.dist-info/entry_points.txt +2 -0
  68. daita_agents-0.1.0.dist-info/licenses/LICENSE +56 -0
  69. daita_agents-0.1.0.dist-info/top_level.txt +1 -0
daita/llm/base.py ADDED
@@ -0,0 +1,373 @@
+ """
+ Updated BaseLLMProvider with Unified Tracing Integration
+
+ This replaces the old BaseLLMProvider to use the unified tracing system.
+ All LLM calls are automatically traced without user configuration.
+
+ Key Changes:
+ - Removed old token tracking system completely
+ - Integrated automatic LLM call tracing
+ - Simple cost estimation
+ - Automatic provider/model/token capture
+ - Zero configuration required
+ """
+
+ from abc import ABC, abstractmethod
+ from typing import Dict, Any, Optional, List
+ from contextlib import asynccontextmanager
+ import logging
+ import time
+
+ from ..core.tracing import get_trace_manager, TraceType
+ from ..core.interfaces import LLMProvider
+
+ logger = logging.getLogger(__name__)
+
+ class BaseLLMProvider(LLMProvider, ABC):
+     """
+     Base class for LLM providers with automatic call tracing.
+
+     Every LLM call is automatically traced with:
+     - Provider and model details
+     - Token usage and costs
+     - Latency and performance
+     - Input/output content (preview)
+     - Error tracking
+
+     Users get full LLM observability without any configuration.
+     """
+
+     def __init__(
+         self,
+         model: str,
+         api_key: Optional[str] = None,
+         **kwargs
+     ):
+         """
+         Initialize the LLM provider with automatic tracing.
+
+         Args:
+             model: Model identifier
+             api_key: API key for authentication
+             **kwargs: Additional provider-specific options
+         """
+         self.model = model
+         self.api_key = api_key
+         self.config = kwargs
+
+         # Default parameters
+         self.default_params = {
+             'temperature': kwargs.get('temperature', 0.7),
+             'max_tokens': kwargs.get('max_tokens', 1000),
+             'top_p': kwargs.get('top_p', 1.0),
+         }
+
+         # Agent ID for tracing (set by agent)
+         self.agent_id = kwargs.get('agent_id')
+
+         # Get trace manager for automatic tracing
+         self.trace_manager = get_trace_manager()
+
+         # Provider name for tracing
+         self.provider_name = self.__class__.__name__.replace('Provider', '').lower()
+
+         # Last usage for cost estimation
+         self._last_usage = None
+
+         logger.debug(f"Initialized {self.__class__.__name__} with model {model} (automatic tracing enabled)")
+
+     async def generate(self, prompt: str, **kwargs) -> str:
+         """
+         Generate text from a prompt with automatic LLM call tracing.
+
+         Every call is automatically traced with full metadata.
+
+         Args:
+             prompt: Input prompt
+             **kwargs: Optional parameters to override defaults
+
+         Returns:
+             Generated text response
+         """
+         # Merge parameters
+         params = self._merge_params(kwargs)
+
+         # Automatically trace the LLM call
+         async with self.trace_manager.span(
+             operation_name=f"llm_{self.provider_name}_{self.model}",
+             trace_type=TraceType.LLM_CALL,
+             agent_id=self.agent_id,
+             input_data=prompt,
+             llm_provider=self.provider_name,
+             llm_model=self.model,
+             temperature=str(params.get('temperature', 0.7)),
+             max_tokens=str(params.get('max_tokens', 1000))
+         ) as span_id:
+             try:
+                 start_time = time.time()
+
+                 # Execute the actual LLM call
+                 response = await self._generate_impl(prompt, **params)
+
+                 duration_ms = (time.time() - start_time) * 1000
+
+                 # Get token usage from the call
+                 token_usage = self._get_last_token_usage()
+
+                 # Record LLM call details in the trace
+                 self.trace_manager.record_llm_call(
+                     span_id=span_id,
+                     provider=self.provider_name,
+                     model=self.model,
+                     tokens=token_usage
+                 )
+
+                 # Record duration and, when available, estimated cost on the span
+                 if span_id in self.trace_manager._active_spans:
+                     span = self.trace_manager._active_spans[span_id]
+                     span.metadata['duration_ms'] = duration_ms
+                     cost = self._estimate_cost(token_usage)
+                     if cost:
+                         span.metadata['estimated_cost_usd'] = cost
+
+                 logger.debug(f"LLM call completed: {token_usage.get('total_tokens', 0)} tokens in {duration_ms:.1f}ms")
+                 return response
+
+             except Exception as e:
+                 # LLM call failed; the error is automatically recorded by the span context
+                 logger.warning(f"LLM call failed: {str(e)}")
+                 raise
+
+     @abstractmethod
+     async def _generate_impl(self, prompt: str, **kwargs) -> str:
+         """
+         Provider-specific implementation of text generation.
+
+         This method must be implemented by each provider (OpenAI, Anthropic, etc.)
+         and contains the actual LLM API call logic.
+
+         Args:
+             prompt: Input prompt
+             **kwargs: Optional parameters
+
+         Returns:
+             Generated text response
+         """
+         pass
+
+     def _get_last_token_usage(self) -> Dict[str, int]:
+         """
+         Get token usage from the last API call.
+
+         This should be overridden by each provider to return actual
+         token usage from its last API response.
+         """
+         if self._last_usage:
+             # Try to extract from the stored usage object
+             if hasattr(self._last_usage, 'total_tokens'):
+                 # OpenAI format
+                 return {
+                     'total_tokens': getattr(self._last_usage, 'total_tokens', 0),
+                     'prompt_tokens': getattr(self._last_usage, 'prompt_tokens', 0),
+                     'completion_tokens': getattr(self._last_usage, 'completion_tokens', 0)
+                 }
+             elif hasattr(self._last_usage, 'input_tokens'):
+                 # Anthropic format
+                 input_tokens = getattr(self._last_usage, 'input_tokens', 0)
+                 output_tokens = getattr(self._last_usage, 'output_tokens', 0)
+                 return {
+                     'total_tokens': input_tokens + output_tokens,
+                     'prompt_tokens': input_tokens,
+                     'completion_tokens': output_tokens
+                 }
+
+         # Default fallback
+         return {
+             'total_tokens': 0,
+             'prompt_tokens': 0,
+             'completion_tokens': 0
+         }
+
+     def _estimate_cost(self, token_usage: Dict[str, int]) -> Optional[float]:
+         """
+         Simple cost estimation for the MVP.
+
+         Providers can override this with model-specific pricing.
+         """
+         total_tokens = token_usage.get('total_tokens', 0)
+         if total_tokens == 0:
+             return None
+
+         # Generic estimate for the MVP; providers should override.
+         # Example: 1,500 tokens -> (1500 / 1000) * 0.002 = $0.003
+         cost_per_1k_tokens = 0.002  # $0.002 per 1K tokens
+         return (total_tokens / 1000) * cost_per_1k_tokens
+
+     def _merge_params(self, override_params: Dict[str, Any]) -> Dict[str, Any]:
+         """Merge default parameters with overrides."""
+         params = self.default_params.copy()
+         params.update(override_params)
+         return params
+
+     def _validate_api_key(self) -> None:
+         """Validate that an API key is available."""
+         if not self.api_key:
+             raise ValueError(f"API key required for {self.__class__.__name__}")
+
+     def set_agent_id(self, agent_id: str):
+         """
+         Set the agent ID for tracing context.
+
+         This is called automatically by BaseAgent during initialization.
+         """
+         self.agent_id = agent_id
+         logger.debug(f"Set agent ID {agent_id} for {self.provider_name} provider")
+
+     def get_recent_calls(self, limit: int = 10) -> List[Dict[str, Any]]:
+         """Get recent LLM calls for this provider's agent from unified tracing."""
+         if not self.agent_id:
+             return []
+
+         operations = self.trace_manager.get_recent_operations(agent_id=self.agent_id, limit=limit * 2)
+
+         # Filter for LLM calls from this provider
+         llm_calls = [
+             op for op in operations
+             if (op.get('type') == 'llm_call' and
+                 op.get('metadata', {}).get('llm_provider') == self.provider_name)
+         ]
+
+         return llm_calls[:limit]
+
+     def get_token_stats(self) -> Dict[str, Any]:
+         """Get token usage statistics from unified tracing."""
+         if not self.agent_id:
+             return {
+                 'total_calls': 0,
+                 'total_tokens': 0,
+                 'estimated_cost': 0.0
+             }
+
+         metrics = self.trace_manager.get_agent_metrics(self.agent_id)
+
+         # Get actual token usage from the most recent call
+         last_usage = self._get_last_token_usage()
+
+         return {
+             'total_calls': metrics.get('total_operations', 0),  # All operations
+             'total_tokens': last_usage.get('total_tokens', 0),  # From the last API call
+             'prompt_tokens': last_usage.get('prompt_tokens', 0),
+             'completion_tokens': last_usage.get('completion_tokens', 0),
+             'estimated_cost': self._estimate_cost(last_usage) or 0.0,
+             'success_rate': metrics.get('success_rate', 0),
+             'avg_latency_ms': metrics.get('avg_latency_ms', 0)
+         }
+
+     @property
+     def info(self) -> Dict[str, Any]:
+         """Get information about this LLM provider."""
+         return {
+             'provider': self.provider_name,
+             'model': self.model,
+             'agent_id': self.agent_id,
+             'config': {k: v for k, v in self.config.items() if 'key' not in k.lower()},
+             'default_params': self.default_params,
+             'tracing_enabled': True
+         }
+
+
+ # Context manager for batch LLM operations
+ @asynccontextmanager
+ async def traced_llm_batch(llm_provider: BaseLLMProvider, batch_name: str = "llm_batch"):
+     """
+     Context manager for tracing batch LLM operations.
+
+     Usage:
+         async with traced_llm_batch(llm, "document_analysis"):
+             summary = await llm.generate("Summarize: " + doc1)
+             analysis = await llm.generate("Analyze: " + doc2)
+     """
+     trace_manager = get_trace_manager()
+
+     async with trace_manager.span(
+         operation_name=f"llm_batch_{batch_name}",
+         trace_type=TraceType.LLM_CALL,
+         agent_id=llm_provider.agent_id,
+         llm_provider=llm_provider.provider_name,
+         batch_operation=batch_name
+     ):
+         try:
+             yield llm_provider
+         except Exception as e:
+             logger.error(f"LLM batch {batch_name} failed: {e}")
+             raise
+
+
+ # Utility functions
+ def get_llm_traces(agent_id: Optional[str] = None, provider: Optional[str] = None, limit: int = 20) -> List[Dict[str, Any]]:
+     """Get recent LLM call traces from unified tracing."""
+     trace_manager = get_trace_manager()
+     operations = trace_manager.get_recent_operations(agent_id=agent_id, limit=limit * 2)
+
+     # Filter for LLM operations
+     llm_ops = [
+         op for op in operations
+         if op.get('type') == 'llm_call'
+     ]
+
+     # Filter by provider if specified
+     if provider:
+         llm_ops = [
+             op for op in llm_ops
+             if op.get('metadata', {}).get('llm_provider') == provider
+         ]
+
+     return llm_ops[:limit]
+
+
+ def get_llm_stats(agent_id: Optional[str] = None, provider: Optional[str] = None) -> Dict[str, Any]:
+     """Get LLM usage statistics from unified tracing."""
+     traces = get_llm_traces(agent_id, provider, limit=50)
+
+     if not traces:
+         return {"total_calls": 0, "success_rate": 0, "avg_latency_ms": 0}
+
+     total_calls = len(traces)
+     successful_calls = len([t for t in traces if t.get('status') == 'success'])
+
+     # Calculate averages
+     latencies = [t.get('duration_ms', 0) for t in traces if t.get('duration_ms')]
+     avg_latency = sum(latencies) / len(latencies) if latencies else 0
+
+     # Token aggregation (simplified for the MVP)
+     total_tokens = 0
+     for trace in traces:
+         metadata = trace.get('metadata', {})
+         total_tokens += metadata.get('tokens_total', 0)
+
+     return {
+         "total_calls": total_calls,
+         "successful_calls": successful_calls,
+         "failed_calls": total_calls - successful_calls,
+         "success_rate": successful_calls / total_calls if total_calls > 0 else 0,
+         "avg_latency_ms": avg_latency,
+         "total_tokens": total_tokens,
+         "agent_id": agent_id,
+         "provider": provider
+     }
+
+
+ # Export everything
+ __all__ = [
+     # Base class
+     "BaseLLMProvider",
+
+     # Context managers
+     "traced_llm_batch",
+
+     # Utility functions
+     "get_llm_traces",
+     "get_llm_stats"
+ ]
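
A minimal sketch of how a provider built on BaseLLMProvider is exercised, assuming the daita package is importable; EchoProvider and its canned reply are hypothetical, not part of the package:

    import asyncio
    from daita.llm.base import BaseLLMProvider, traced_llm_batch

    class EchoProvider(BaseLLMProvider):
        """Hypothetical provider: echoes the prompt, so no API key is needed."""
        async def _generate_impl(self, prompt: str, **kwargs) -> str:
            return f"echo: {prompt}"

    async def main():
        llm = EchoProvider(model="echo-1", agent_id="demo_agent")
        # generate() wraps _generate_impl() in the automatic tracing span
        print(await llm.generate("hello"))
        # Group several calls under one parent batch span
        async with traced_llm_batch(llm, "demo_batch"):
            await llm.generate("first")
            await llm.generate("second")
        print(llm.get_token_stats())

    asyncio.run(main())

Because generate() owns the tracing span, subclasses only implement _generate_impl() and optionally override _get_last_token_usage() and _estimate_cost().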
daita/llm/factory.py ADDED
@@ -0,0 +1,101 @@
+ """
+ Factory for creating LLM provider instances.
+ """
+ import logging
+ from typing import Optional
+
+ from ..core.exceptions import LLMError
+ from .base import BaseLLMProvider
+ from .openai import OpenAIProvider
+ from .anthropic import AnthropicProvider
+ from .grok import GrokProvider
+ from .gemini import GeminiProvider
+ from .mock import MockLLMProvider
+
+ logger = logging.getLogger(__name__)
+
+ # Registry of available providers
+ PROVIDER_REGISTRY = {
+     'openai': OpenAIProvider,
+     'anthropic': AnthropicProvider,
+     'grok': GrokProvider,
+     'gemini': GeminiProvider,
+     'mock': MockLLMProvider,
+ }
+
+ def create_llm_provider(
+     provider: str,
+     model: str,
+     api_key: Optional[str] = None,
+     agent_id: Optional[str] = None,
+     **kwargs
+ ) -> BaseLLMProvider:
+     """
+     Factory function to create LLM provider instances.
+
+     Args:
+         provider: Provider name ('openai', 'anthropic', 'grok', 'gemini', 'mock')
+         model: Model identifier
+         api_key: API key for authentication
+         agent_id: Agent ID for automatic call tracing
+         **kwargs: Additional provider-specific parameters
+
+     Returns:
+         LLM provider instance
+
+     Raises:
+         LLMError: If the provider is not supported or fails to initialize
+
+     Examples:
+         >>> # Create OpenAI provider with automatic tracing
+         >>> llm = create_llm_provider('openai', 'gpt-4', api_key='sk-...', agent_id='my_agent')
+
+         >>> # Create Anthropic provider with automatic tracing
+         >>> llm = create_llm_provider('anthropic', 'claude-3-sonnet-20240229', agent_id='my_agent')
+
+         >>> # Create Grok provider
+         >>> llm = create_llm_provider('grok', 'grok-beta', api_key='xai-...', agent_id='my_agent')
+
+         >>> # Create Gemini provider
+         >>> llm = create_llm_provider('gemini', 'gemini-1.5-pro', api_key='AIza...', agent_id='my_agent')
+
+         >>> # Create mock provider for testing
+         >>> llm = create_llm_provider('mock', 'test-model', agent_id='test_agent')
+     """
+     provider_name = provider.lower()
+
+     if provider_name not in PROVIDER_REGISTRY:
+         available_providers = list(PROVIDER_REGISTRY.keys())
+         raise LLMError(
+             f"Unsupported LLM provider: {provider}. "
+             f"Available providers: {available_providers}"
+         )
+
+     provider_class = PROVIDER_REGISTRY[provider_name]
+
+     try:
+         # Pass agent_id through so every call is traced under that agent
+         return provider_class(
+             model=model,
+             api_key=api_key,
+             agent_id=agent_id,
+             **kwargs
+         )
+     except Exception as e:
+         logger.error(f"Failed to create {provider} provider: {str(e)}")
+         raise LLMError(f"Failed to create {provider} provider: {str(e)}") from e
+
+ def register_llm_provider(name: str, provider_class) -> None:
+     """
+     Register a custom LLM provider.
+
+     Args:
+         name: Provider name
+         provider_class: Provider class that implements the LLMProvider interface
+     """
+     PROVIDER_REGISTRY[name.lower()] = provider_class
+     logger.info(f"Registered custom LLM provider: {name}")
+
+ def list_available_providers() -> list:
+     """Get list of available LLM providers."""
+     return list(PROVIDER_REGISTRY.keys())
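
Beyond the docstring examples, the registry is also extensible at runtime. A short sketch, assuming the daita package is importable; MyProvider is a hypothetical stand-in for a real implementation:

    from daita.llm.base import BaseLLMProvider
    from daita.llm.factory import (
        create_llm_provider,
        list_available_providers,
        register_llm_provider,
    )

    class MyProvider(BaseLLMProvider):
        """Hypothetical stub provider for demonstration."""
        async def _generate_impl(self, prompt: str, **kwargs) -> str:
            return "stub response"

    # Register under a lowercase name, then construct through the factory
    register_llm_provider("myprovider", MyProvider)
    print(list_available_providers())  # [..., 'myprovider']
    llm = create_llm_provider("myprovider", model="stub-1", agent_id="demo_agent")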
daita/llm/gemini.py ADDED
@@ -0,0 +1,152 @@
+ """
+ Google Gemini LLM provider implementation with integrated tracing.
+ """
+ import os
+ import logging
+ import asyncio
+ from typing import Dict, Any, Optional
+
+ from ..core.exceptions import LLMError
+ from .base import BaseLLMProvider
+
+ logger = logging.getLogger(__name__)
+
+ class GeminiProvider(BaseLLMProvider):
+     """Google Gemini LLM provider implementation with automatic call tracing."""
+
+     def __init__(
+         self,
+         model: str = "gemini-1.5-flash",
+         api_key: Optional[str] = None,
+         **kwargs
+     ):
+         """
+         Initialize the Gemini provider.
+
+         Args:
+             model: Gemini model name (e.g., "gemini-1.5-flash", "gemini-1.5-pro", "gemini-1.0-pro")
+             api_key: Google AI API key
+             **kwargs: Additional Gemini-specific parameters
+         """
+         # Get API key from parameter or environment
+         api_key = api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
+
+         super().__init__(model=model, api_key=api_key, **kwargs)
+
+         # Gemini-specific default parameters
+         self.default_params.update({
+             'timeout': kwargs.get('timeout', 60),
+             'safety_settings': kwargs.get('safety_settings', None),
+             'generation_config': kwargs.get('generation_config', None)
+         })
+
+         # Lazy-load Gemini client
+         self._client = None
+
+     @property
+     def client(self):
+         """Lazy-load the Google Generative AI client."""
+         if self._client is None:
+             try:
+                 import google.generativeai as genai
+                 self._validate_api_key()
+
+                 # Configure the API key
+                 genai.configure(api_key=self.api_key)
+
+                 # Create the generative model
+                 self._client = genai.GenerativeModel(self.model)
+                 logger.debug("Gemini client initialized")
+             except ImportError:
+                 raise LLMError(
+                     "Google Generative AI package not installed. Install with: pip install google-generativeai"
+                 )
+         return self._client
+
+     async def _generate_impl(self, prompt: str, **kwargs) -> str:
+         """
+         Provider-specific implementation of text generation for Gemini.
+
+         This method contains the actual Gemini API call logic and is automatically
+         wrapped with tracing by the base class generate() method.
+
+         Args:
+             prompt: Input prompt
+             **kwargs: Optional parameters
+
+         Returns:
+             Generated text response
+         """
+         try:
+             # Merge parameters
+             params = self._merge_params(kwargs)
+
+             # Prepare generation config
+             generation_config = params.get('generation_config', {})
+             if not generation_config:
+                 generation_config = {
+                     'max_output_tokens': params.get('max_tokens'),
+                     'temperature': params.get('temperature'),
+                     'top_p': params.get('top_p')
+                 }
+
+             # Make the API call. generate_content is synchronous in the
+             # google-generativeai SDK, so run it in an executor unless the
+             # client exposes a coroutine.
+             if asyncio.iscoroutinefunction(self.client.generate_content):
+                 response = await self.client.generate_content(
+                     prompt,
+                     generation_config=generation_config,
+                     safety_settings=params.get('safety_settings')
+                 )
+             else:
+                 # Run the synchronous method in an executor
+                 loop = asyncio.get_running_loop()
+                 response = await loop.run_in_executor(
+                     None,
+                     lambda: self.client.generate_content(
+                         prompt,
+                         generation_config=generation_config,
+                         safety_settings=params.get('safety_settings')
+                     )
+                 )
+
+             # Store usage info if available (Gemini's usage tracking varies)
+             if hasattr(response, 'usage_metadata'):
+                 self._last_usage = response.usage_metadata
+
+             return response.text
+
+         except Exception as e:
+             logger.error(f"Gemini generation failed: {str(e)}")
+             raise LLMError(f"Gemini generation failed: {str(e)}") from e
+
+     def _get_last_token_usage(self) -> Dict[str, int]:
+         """
+         Override the base class method to handle Gemini's token format.
+
+         Gemini uses different token field names in usage_metadata.
+         """
+         if self._last_usage:
+             # Gemini's format varies; extract what we can
+             prompt_tokens = getattr(self._last_usage, 'prompt_token_count', 0)
+             completion_tokens = getattr(self._last_usage, 'candidates_token_count', 0)
+             total_tokens = getattr(self._last_usage, 'total_token_count', prompt_tokens + completion_tokens)
+
+             return {
+                 'total_tokens': total_tokens,
+                 'prompt_tokens': prompt_tokens,
+                 'completion_tokens': completion_tokens
+             }
+
+         # Fall back to the base class default
+         return super()._get_last_token_usage()
+
+     @property
+     def info(self) -> Dict[str, Any]:
+         """Get information about the Gemini provider."""
+         base_info = super().info
+         base_info.update({
+             'provider_name': 'Google Gemini',
+             'api_compatible': 'Google AI'
+         })
+         return base_info
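
A minimal usage sketch for the provider above, assuming daita and google-generativeai are installed and GOOGLE_API_KEY is set in the environment; the prompt and agent_id are illustrative:

    import asyncio
    from daita.llm.gemini import GeminiProvider

    async def main():
        llm = GeminiProvider(model="gemini-1.5-flash", agent_id="demo_agent")
        # generate() is inherited from BaseLLMProvider and traces the call;
        # max_tokens is mapped to Gemini's max_output_tokens in _generate_impl()
        text = await llm.generate("One sentence on observability.", max_tokens=64)
        print(text)
        print(llm.info)  # provider, model, agent_id, sanitized config

    asyncio.run(main())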