stratifyai-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. cli/__init__.py +5 -0
  2. cli/stratifyai_cli.py +1753 -0
  3. stratifyai/__init__.py +113 -0
  4. stratifyai/api_key_helper.py +372 -0
  5. stratifyai/caching.py +279 -0
  6. stratifyai/chat/__init__.py +54 -0
  7. stratifyai/chat/builder.py +366 -0
  8. stratifyai/chat/stratifyai_anthropic.py +194 -0
  9. stratifyai/chat/stratifyai_bedrock.py +200 -0
  10. stratifyai/chat/stratifyai_deepseek.py +194 -0
  11. stratifyai/chat/stratifyai_google.py +194 -0
  12. stratifyai/chat/stratifyai_grok.py +194 -0
  13. stratifyai/chat/stratifyai_groq.py +195 -0
  14. stratifyai/chat/stratifyai_ollama.py +201 -0
  15. stratifyai/chat/stratifyai_openai.py +209 -0
  16. stratifyai/chat/stratifyai_openrouter.py +201 -0
  17. stratifyai/chunking.py +158 -0
  18. stratifyai/client.py +292 -0
  19. stratifyai/config.py +1273 -0
  20. stratifyai/cost_tracker.py +257 -0
  21. stratifyai/embeddings.py +245 -0
  22. stratifyai/exceptions.py +91 -0
  23. stratifyai/models.py +59 -0
  24. stratifyai/providers/__init__.py +5 -0
  25. stratifyai/providers/anthropic.py +330 -0
  26. stratifyai/providers/base.py +183 -0
  27. stratifyai/providers/bedrock.py +634 -0
  28. stratifyai/providers/deepseek.py +39 -0
  29. stratifyai/providers/google.py +39 -0
  30. stratifyai/providers/grok.py +39 -0
  31. stratifyai/providers/groq.py +39 -0
  32. stratifyai/providers/ollama.py +43 -0
  33. stratifyai/providers/openai.py +344 -0
  34. stratifyai/providers/openai_compatible.py +372 -0
  35. stratifyai/providers/openrouter.py +39 -0
  36. stratifyai/py.typed +2 -0
  37. stratifyai/rag.py +381 -0
  38. stratifyai/retry.py +185 -0
  39. stratifyai/router.py +643 -0
  40. stratifyai/summarization.py +179 -0
  41. stratifyai/utils/__init__.py +11 -0
  42. stratifyai/utils/bedrock_validator.py +136 -0
  43. stratifyai/utils/code_extractor.py +327 -0
  44. stratifyai/utils/csv_extractor.py +197 -0
  45. stratifyai/utils/file_analyzer.py +192 -0
  46. stratifyai/utils/json_extractor.py +219 -0
  47. stratifyai/utils/log_extractor.py +267 -0
  48. stratifyai/utils/model_selector.py +324 -0
  49. stratifyai/utils/provider_validator.py +442 -0
  50. stratifyai/utils/token_counter.py +186 -0
  51. stratifyai/vectordb.py +344 -0
  52. stratifyai-0.1.0.dist-info/METADATA +263 -0
  53. stratifyai-0.1.0.dist-info/RECORD +57 -0
  54. stratifyai-0.1.0.dist-info/WHEEL +5 -0
  55. stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
  56. stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
  57. stratifyai-0.1.0.dist-info/top_level.txt +2 -0
stratifyai/cost_tracker.py ADDED
@@ -0,0 +1,257 @@
+ """Cost tracking module for LLM API calls."""
+
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from typing import Any, Dict, List, Optional
+ from collections import defaultdict
+
+
+ @dataclass
+ class CostEntry:
+     """Individual cost entry for an LLM API call."""
+
+     timestamp: datetime
+     provider: str
+     model: str
+     prompt_tokens: int
+     completion_tokens: int
+     total_tokens: int
+     cost_usd: float
+     request_id: str
+     cached_tokens: int = 0
+     cache_creation_tokens: int = 0
+     cache_read_tokens: int = 0
+     group: Optional[str] = None
+
+
+ class CostTracker:
+     """
+     Track and analyze costs across LLM API calls.
+
+     Features:
+     - Call history with detailed metrics
+     - Grouping by provider, model, or custom tags
+     - Cost analytics and reporting
+     - Budget tracking and alerts
+     """
+
+     def __init__(self):
+         """Initialize cost tracker."""
+         self._entries: List[CostEntry] = []
+         self._total_cost: float = 0.0
+         self._budget_limit: Optional[float] = None
+         self._alert_threshold: Optional[float] = None
+
+     def add_entry(
+         self,
+         provider: str,
+         model: str,
+         prompt_tokens: int,
+         completion_tokens: int,
+         total_tokens: int,
+         cost_usd: float,
+         request_id: str,
+         cached_tokens: int = 0,
+         cache_creation_tokens: int = 0,
+         cache_read_tokens: int = 0,
+         group: Optional[str] = None,
+     ) -> None:
+         """
+         Add a cost entry to the tracker.
+
+         Args:
+             provider: Provider name (e.g., 'openai', 'anthropic')
+             model: Model name
+             prompt_tokens: Number of prompt tokens
+             completion_tokens: Number of completion tokens
+             total_tokens: Total tokens used
+             cost_usd: Cost in USD
+             request_id: Unique request identifier
+             cached_tokens: Number of cached tokens
+             cache_creation_tokens: Tokens written to cache
+             cache_read_tokens: Tokens read from cache
+             group: Optional group tag for categorization
+         """
+         entry = CostEntry(
+             timestamp=datetime.now(),
+             provider=provider,
+             model=model,
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens,
+             total_tokens=total_tokens,
+             cost_usd=cost_usd,
+             request_id=request_id,
+             cached_tokens=cached_tokens,
+             cache_creation_tokens=cache_creation_tokens,
+             cache_read_tokens=cache_read_tokens,
+             group=group,
+         )
+         self._entries.append(entry)
+         self._total_cost += cost_usd
+
+         # Check budget alerts
+         if self._alert_threshold and self._total_cost >= self._alert_threshold:
+             self._trigger_alert(self._total_cost, self._alert_threshold)
+
+     def get_total_cost(self) -> float:
+         """Get total cost across all tracked calls."""
+         return self._total_cost
+
+     def get_total_tokens(self) -> int:
+         """Get total tokens across all tracked calls."""
+         return sum(entry.total_tokens for entry in self._entries)
+
+     def get_call_count(self) -> int:
+         """Get total number of tracked calls."""
+         return len(self._entries)
+
+     def get_entries(
+         self,
+         provider: Optional[str] = None,
+         model: Optional[str] = None,
+         group: Optional[str] = None,
+     ) -> List[CostEntry]:
+         """
+         Get filtered cost entries.
+
+         Args:
+             provider: Filter by provider name
+             model: Filter by model name
+             group: Filter by group tag
+
+         Returns:
+             List of matching cost entries
+         """
+         entries = self._entries
+
+         if provider:
+             entries = [e for e in entries if e.provider == provider]
+         if model:
+             entries = [e for e in entries if e.model == model]
+         if group:
+             entries = [e for e in entries if e.group == group]
+
+         return entries
+
+     def get_cost_by_provider(self) -> Dict[str, float]:
+         """Get total cost grouped by provider."""
+         costs: Dict[str, float] = defaultdict(float)
+         for entry in self._entries:
+             costs[entry.provider] += entry.cost_usd
+         return dict(costs)
+
+     def get_cost_by_model(self) -> Dict[str, float]:
+         """Get total cost grouped by model."""
+         costs: Dict[str, float] = defaultdict(float)
+         for entry in self._entries:
+             costs[entry.model] += entry.cost_usd
+         return dict(costs)
+
+     def get_cost_by_group(self) -> Dict[str, float]:
+         """Get total cost grouped by custom group tag."""
+         costs: Dict[str, float] = defaultdict(float)
+         for entry in self._entries:
+             if entry.group:
+                 costs[entry.group] += entry.cost_usd
+         return dict(costs)
+
+     def get_tokens_by_provider(self) -> Dict[str, int]:
+         """Get total tokens grouped by provider."""
+         tokens: Dict[str, int] = defaultdict(int)
+         for entry in self._entries:
+             tokens[entry.provider] += entry.total_tokens
+         return dict(tokens)
+
+     def get_cache_stats(self) -> Dict[str, Any]:
+         """Get cache usage statistics."""
+         total_cache_reads = sum(e.cache_read_tokens for e in self._entries)
+         total_cache_creates = sum(e.cache_creation_tokens for e in self._entries)
+         total_prompt_tokens = sum(e.prompt_tokens for e in self._entries)
+
+         cache_hit_rate = 0.0
+         if total_prompt_tokens > 0:
+             cache_hit_rate = (total_cache_reads / total_prompt_tokens) * 100
+
+         return {
+             "total_cache_read_tokens": total_cache_reads,
+             "total_cache_creation_tokens": total_cache_creates,
+             "cache_hit_rate_percent": round(cache_hit_rate, 2),
+         }
+
+     def set_budget(self, limit: float, alert_threshold: Optional[float] = None) -> None:
+         """
+         Set budget limit and optional alert threshold.
+
+         Args:
+             limit: Maximum budget in USD
+             alert_threshold: Alert when cost reaches this threshold (default: 80% of limit)
+         """
+         self._budget_limit = limit
+         self._alert_threshold = alert_threshold or (limit * 0.8)
+
+     def get_budget_status(self) -> Dict[str, Any]:
+         """
+         Get current budget status.
+
+         Returns:
+             Dictionary with budget information
+         """
+         if self._budget_limit is None:
+             return {
+                 "budget_set": False,
+                 "total_cost": self._total_cost,
+             }
+
+         remaining = self._budget_limit - self._total_cost
+         percent_used = (self._total_cost / self._budget_limit) * 100
+
+         return {
+             "budget_set": True,
+             "budget_limit": self._budget_limit,
+             "total_cost": self._total_cost,
+             "remaining": max(0, remaining),
+             "percent_used": round(percent_used, 2),
+             "over_budget": self._total_cost > self._budget_limit,
+             "alert_threshold": self._alert_threshold,
+         }
+
+     def is_over_budget(self) -> bool:
+         """Check if current spending exceeds budget limit."""
+         if self._budget_limit is None:
+             return False
+         return self._total_cost > self._budget_limit
+
+     def reset(self) -> None:
+         """Reset all tracked data."""
+         self._entries.clear()
+         self._total_cost = 0.0
+
+     def _trigger_alert(self, current_cost: float, threshold: float) -> None:
+         """
+         Trigger budget alert (can be overridden for custom behavior).
+
+         Args:
+             current_cost: Current total cost
+             threshold: Alert threshold that was exceeded
+         """
+         # Default implementation: print warning
+         # Override this method for custom alert behavior (email, webhook, etc.)
+         print(f"⚠️ Budget Alert: Current cost ${current_cost:.4f} exceeds threshold ${threshold:.4f}")
+
+     def get_summary(self) -> Dict[str, Any]:
+         """
+         Get comprehensive summary of tracked costs.
+
+         Returns:
+             Dictionary with summary statistics
+         """
+         return {
+             "total_cost": self._total_cost,
+             "total_tokens": self.get_total_tokens(),
+             "total_calls": self.get_call_count(),
+             "cost_by_provider": self.get_cost_by_provider(),
+             "cost_by_model": self.get_cost_by_model(),
+             "tokens_by_provider": self.get_tokens_by_provider(),
+             "cache_stats": self.get_cache_stats(),
+             "budget_status": self.get_budget_status(),
+         }
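Usage note: the CostTracker API above is plain Python with no external dependencies, so a short sketch may help tie the methods together. All values below (model name, token counts, cost, request ID, group tag) are invented for illustration:

    from stratifyai.cost_tracker import CostTracker

    tracker = CostTracker()
    tracker.set_budget(10.00)  # alert threshold defaults to 80% of the limit ($8.00)

    # Record one hypothetical call; these numbers are made up for the example.
    tracker.add_entry(
        provider="openai",
        model="gpt-4o-mini",
        prompt_tokens=1200,
        completion_tokens=300,
        total_tokens=1500,
        cost_usd=0.0009,
        request_id="req-001",
        group="summarization",
    )

    print(tracker.get_total_cost())         # 0.0009
    print(tracker.get_cost_by_provider())   # {'openai': 0.0009}
    print(tracker.is_over_budget())         # False
    print(tracker.get_summary()["cache_stats"])

Subclassing CostTracker and overriding _trigger_alert is the intended hook for routing budget alerts somewhere other than stdout.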
stratifyai/embeddings.py ADDED
@@ -0,0 +1,245 @@
+ """Embedding generation for RAG and semantic search.
+
+ This module provides an abstraction for generating embeddings from text using
+ various provider APIs (OpenAI, Cohere, etc.).
+ """
+
+ import asyncio
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import List, Optional
+ import os
+ from openai import AsyncOpenAI
+
+ from .exceptions import ProviderAPIError, AuthenticationError
+
+
+ @dataclass
+ class EmbeddingResult:
+     """Result of an embedding generation request.
+
+     Attributes:
+         embeddings: List of embedding vectors (each is List[float])
+         model: Name of the embedding model used
+         total_tokens: Total tokens processed
+         cost: Cost of the embedding request in USD
+     """
+     embeddings: List[List[float]]
+     model: str
+     total_tokens: int
+     cost: float
+
+
+ class EmbeddingProvider(ABC):
+     """Abstract base class for embedding providers.
+
+     All embedding provider implementations must inherit from this class
+     and implement the generate_embeddings method.
+     """
+
+     @abstractmethod
+     async def generate_embeddings(
+         self,
+         texts: List[str],
+         model: Optional[str] = None
+     ) -> EmbeddingResult:
+         """Generate embeddings for a list of texts.
+
+         Args:
+             texts: List of text strings to embed
+             model: Optional model name override
+
+         Returns:
+             EmbeddingResult with embeddings and metadata
+
+         Raises:
+             ProviderAPIError: If the API request fails
+             AuthenticationError: If authentication fails
+         """
+         pass
+
+     @abstractmethod
+     def get_embedding_dimension(self, model: str) -> int:
+         """Get the dimensionality of embeddings for a given model.
+
+         Args:
+             model: Model name
+
+         Returns:
+             Embedding dimension (e.g., 1536 for text-embedding-3-small)
+         """
+         pass
+
+     def generate_embeddings_sync(
+         self,
+         texts: List[str],
+         model: Optional[str] = None
+     ) -> EmbeddingResult:
+         """Synchronous wrapper for generate_embeddings."""
+         return asyncio.run(self.generate_embeddings(texts, model))
+
+
+ class OpenAIEmbeddingProvider(EmbeddingProvider):
+     """OpenAI embedding provider implementation.
+
+     Supports:
+     - text-embedding-3-small (1536 dimensions)
+     - text-embedding-3-large (3072 dimensions)
+     - text-embedding-ada-002 (1536 dimensions, legacy)
+     """
+
+     # Embedding costs per 1M tokens (as of Feb 2026)
+     EMBEDDING_COSTS = {
+         "text-embedding-3-small": 0.020 / 1_000_000,  # $0.020 per 1M tokens
+         "text-embedding-3-large": 0.130 / 1_000_000,  # $0.130 per 1M tokens
+         "text-embedding-ada-002": 0.100 / 1_000_000,  # $0.100 per 1M tokens (legacy)
+     }
+
+     # Embedding dimensions by model
+     EMBEDDING_DIMENSIONS = {
+         "text-embedding-3-small": 1536,
+         "text-embedding-3-large": 3072,
+         "text-embedding-ada-002": 1536,
+     }
+
+     # Default model
+     DEFAULT_MODEL = "text-embedding-3-small"
+
+     def __init__(self, api_key: Optional[str] = None):
+         """Initialize OpenAI embedding provider.
+
+         Args:
+             api_key: OpenAI API key. If None, reads from OPENAI_API_KEY env var.
+
+         Raises:
+             AuthenticationError: If no API key is provided or found
+         """
+         self.api_key = api_key or os.getenv("OPENAI_API_KEY")
+         if not self.api_key:
+             # AuthenticationError (exceptions.py) takes the provider name and tells
+             # the caller to check the API key; set OPENAI_API_KEY or pass api_key.
+             raise AuthenticationError("openai")
+
+         self.client = AsyncOpenAI(api_key=self.api_key)
+
+     async def generate_embeddings(
+         self,
+         texts: List[str],
+         model: Optional[str] = None
+     ) -> EmbeddingResult:
+         """Generate embeddings for a list of texts using OpenAI.
+
+         Args:
+             texts: List of text strings to embed
+             model: Model name (default: text-embedding-3-small)
+
+         Returns:
+             EmbeddingResult with embeddings and metadata
+
+         Raises:
+             ProviderAPIError: If the API request fails
+             AuthenticationError: If authentication fails
+         """
+         model = model or self.DEFAULT_MODEL
+
+         if model not in self.EMBEDDING_COSTS:
+             raise ValueError(
+                 f"Unknown OpenAI embedding model: {model}. "
+                 f"Supported models: {list(self.EMBEDDING_COSTS.keys())}"
+             )
+
+         if not texts:
+             return EmbeddingResult(
+                 embeddings=[],
+                 model=model,
+                 total_tokens=0,
+                 cost=0.0
+             )
+
+         try:
+             # Call OpenAI API
+             response = await self.client.embeddings.create(
+                 input=texts,
+                 model=model
+             )
+
+             # Extract embeddings
+             embeddings = [data.embedding for data in response.data]
+
+             # Calculate cost
+             total_tokens = response.usage.total_tokens
+             cost = total_tokens * self.EMBEDDING_COSTS[model]
+
+             return EmbeddingResult(
+                 embeddings=embeddings,
+                 model=model,
+                 total_tokens=total_tokens,
+                 cost=cost
+             )
+
+         except Exception as e:
+             error_msg = str(e)
+             if "authentication" in error_msg.lower() or "api key" in error_msg.lower():
+                 # AuthenticationError expects the provider name (see exceptions.py).
+                 raise AuthenticationError("openai") from e
+             else:
+                 # ProviderAPIError requires both a message and the provider name.
+                 raise ProviderAPIError(
+                     f"OpenAI embedding request failed: {error_msg}", provider="openai"
+                 ) from e
+
+     async def generate_embedding(self, text: str, model: Optional[str] = None) -> List[float]:
+         """Generate embedding for a single text string.
+
+         Convenience method for single text embedding.
+
+         Args:
+             text: Text string to embed
+             model: Model name (default: text-embedding-3-small)
+
+         Returns:
+             Embedding vector as List[float]
+         """
+         result = await self.generate_embeddings([text], model=model)
+         return result.embeddings[0]
+
+     def get_embedding_dimension(self, model: str) -> int:
+         """Get the dimensionality of embeddings for a given model.
+
+         Args:
+             model: Model name
+
+         Returns:
+             Embedding dimension
+
+         Raises:
+             ValueError: If model is unknown
+         """
+         if model not in self.EMBEDDING_DIMENSIONS:
+             raise ValueError(
+                 f"Unknown OpenAI embedding model: {model}. "
+                 f"Supported models: {list(self.EMBEDDING_DIMENSIONS.keys())}"
+             )
+         return self.EMBEDDING_DIMENSIONS[model]
+
+
+ def create_embedding_provider(
+     provider: str = "openai",
+     api_key: Optional[str] = None
+ ) -> EmbeddingProvider:
+     """Factory function to create embedding providers.
+
+     Args:
+         provider: Provider name (currently only "openai" supported)
+         api_key: API key for the provider
+
+     Returns:
+         EmbeddingProvider instance
+
+     Raises:
+         ValueError: If provider is unknown
+     """
+     if provider.lower() == "openai":
+         return OpenAIEmbeddingProvider(api_key=api_key)
+     else:
+         raise ValueError(
+             f"Unknown embedding provider: {provider}. "
+             f"Currently supported: openai"
+         )
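Usage note: a minimal sketch of how the factory and the async API above fit together. It assumes OPENAI_API_KEY is set in the environment and makes a real API call; the sample texts are invented:

    import asyncio
    from stratifyai.embeddings import create_embedding_provider

    async def main():
        # Raises AuthenticationError if no key is configured.
        provider = create_embedding_provider("openai")
        result = await provider.generate_embeddings(
            ["first document", "second document"],
            model="text-embedding-3-small",
        )
        print(len(result.embeddings))     # 2 vectors
        print(len(result.embeddings[0]))  # 1536 dimensions for this model
        print(result.total_tokens, result.cost)

    asyncio.run(main())

For synchronous code paths, generate_embeddings_sync wraps the same call in asyncio.run, so it must not be invoked from inside an already-running event loop.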
stratifyai/exceptions.py ADDED
@@ -0,0 +1,91 @@
+ """Custom exceptions for LLM abstraction layer."""
+
+ from typing import Optional
+
+
+ class LLMAbstractionError(Exception):
+     """Base exception for all LLM abstraction errors."""
+     pass
+
+
+ class ProviderError(LLMAbstractionError):
+     """Base exception for provider-specific errors."""
+     pass
+
+
+ class InvalidProviderError(ProviderError):
+     """Raised when an invalid provider is specified."""
+     pass
+
+
+ class ProviderAPIError(ProviderError):
+     """Raised when a provider API call fails."""
+
+     def __init__(self, message: str, provider: str, status_code: Optional[int] = None):
+         self.provider = provider
+         self.status_code = status_code
+         super().__init__(f"[{provider}] {message}")
+
+
+ class AuthenticationError(ProviderError):
+     """Raised when API key authentication fails."""
+
+     def __init__(self, provider: str):
+         self.provider = provider
+         super().__init__(f"Authentication failed for {provider}. Check API key.")
+
+
+ class InsufficientBalanceError(ProviderError):
+     """Raised when provider account has insufficient balance."""
+
+     def __init__(self, provider: str):
+         self.provider = provider
+         super().__init__(f"Insufficient balance in {provider} account. Please add credits.")
+
+
+ class RateLimitError(ProviderError):
+     """Raised when rate limit is exceeded."""
+
+     def __init__(self, provider: str, retry_after: Optional[int] = None):
+         self.provider = provider
+         self.retry_after = retry_after
+         message = f"Rate limit exceeded for {provider}"
+         if retry_after:
+             message += f". Retry after {retry_after} seconds"
+         super().__init__(message)
+
+
+ class InvalidModelError(ProviderError):
+     """Raised when an invalid model is specified for a provider."""
+
+     def __init__(self, model: str, provider: str):
+         self.model = model
+         self.provider = provider
+         super().__init__(f"Model '{model}' not supported by {provider}")
+
+
+ class BudgetExceededError(LLMAbstractionError):
+     """Raised when budget limit is exceeded."""
+
+     def __init__(self, current_cost: float, budget_limit: float):
+         self.current_cost = current_cost
+         self.budget_limit = budget_limit
+         super().__init__(
+             f"Budget limit ${budget_limit:.2f} exceeded. "
+             f"Current spend: ${current_cost:.2f}"
+         )
+
+
+ class MaxRetriesExceededError(LLMAbstractionError):
+     """Raised when maximum retry attempts are exceeded."""
+
+     def __init__(self, attempts: int, last_error: Exception):
+         self.attempts = attempts
+         self.last_error = last_error
+         super().__init__(
+             f"Maximum retry attempts ({attempts}) exceeded. "
+             f"Last error: {str(last_error)}"
+         )
+
+
+ class ValidationError(LLMAbstractionError):
+     """Raised when input validation fails."""
+     pass
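Usage note: because every class above ultimately derives from LLMAbstractionError, callers can catch at whatever granularity they need. A minimal sketch, with an invented failure standing in for a real API call:

    from stratifyai.exceptions import (
        LLMAbstractionError,
        ProviderAPIError,
        RateLimitError,
    )

    def call_provider():
        # Stand-in for a real API call; raises a rate-limit error for illustration.
        raise RateLimitError("openai", retry_after=30)

    try:
        call_provider()
    except RateLimitError as e:
        print(f"Backing off for {e.retry_after}s: {e}")
    except ProviderAPIError as e:
        print(f"API error from {e.provider} (status {e.status_code}): {e}")
    except LLMAbstractionError as e:
        print(f"Unhandled stratifyai error: {e}")

The structured attributes (retry_after, status_code, provider) let retry logic branch on data rather than parsing the message string.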
stratifyai/models.py ADDED
@@ -0,0 +1,59 @@
+ """Data models for unified LLM abstraction layer."""
+
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from typing import List, Literal, Optional
+
+
+ @dataclass
+ class Message:
+     """Standard message format for all providers (OpenAI-compatible)."""
+     role: Literal["system", "user", "assistant"]
+     content: str
+     name: Optional[str] = None  # For multi-agent scenarios
+     cache_control: Optional[dict] = None  # For providers that support prompt caching (Anthropic, OpenAI)
+
+
+ @dataclass
+ class Usage:
+     """Token usage and cost information."""
+     prompt_tokens: int
+     completion_tokens: int
+     total_tokens: int
+     cached_tokens: int = 0  # Tokens retrieved from cache
+     cache_creation_tokens: int = 0  # Tokens written to cache (Anthropic)
+     cache_read_tokens: int = 0  # Tokens read from cache (Anthropic)
+     reasoning_tokens: int = 0  # For reasoning models like o1/o3
+     cost_usd: float = 0.0
+     cost_breakdown: Optional[dict] = None  # Detailed cost breakdown by token type
+
+
+ @dataclass
+ class ChatRequest:
+     """Unified request structure for chat completions."""
+     model: str
+     messages: List[Message]
+     temperature: float = 0.7
+     max_tokens: Optional[int] = None
+     stream: bool = False
+     top_p: float = 1.0
+     frequency_penalty: float = 0.0
+     presence_penalty: float = 0.0
+     stop: Optional[List[str]] = None
+     # Provider-specific extensions
+     reasoning_effort: Optional[str] = None  # OpenAI o1/o3
+     extra_params: dict = field(default_factory=dict)
+
+
+ @dataclass
+ class ChatResponse:
+     """Standard response from any provider."""
+     id: str
+     model: str
+     content: str
+     finish_reason: str
+     usage: Usage
+     provider: str
+     created_at: datetime
+     raw_response: dict  # Original provider response for debugging
+     latency_ms: Optional[float] = None  # Response latency in milliseconds
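Usage note: a short sketch of building a request from these dataclasses. The model string and message contents are invented for illustration:

    from stratifyai.models import ChatRequest, Message

    request = ChatRequest(
        model="gpt-4o-mini",  # any model string your chosen provider accepts
        messages=[
            Message(role="system", content="You are a concise assistant."),
            Message(role="user", content="Summarize prompt caching in one sentence."),
        ],
        temperature=0.2,
        max_tokens=128,
    )

    # Dataclass defaults fill in the rest of the OpenAI-compatible surface.
    print(request.stream, request.top_p, request.extra_params)  # False 1.0 {}

Provider-specific options that have no unified field go in extra_params, which keeps the ChatRequest surface stable across backends.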
stratifyai/providers/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """LLM provider implementations."""
+
+ from .base import BaseProvider
+
+ __all__ = ["BaseProvider"]