stratifyai 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +5 -0
- cli/stratifyai_cli.py +1753 -0
- stratifyai/__init__.py +113 -0
- stratifyai/api_key_helper.py +372 -0
- stratifyai/caching.py +279 -0
- stratifyai/chat/__init__.py +54 -0
- stratifyai/chat/builder.py +366 -0
- stratifyai/chat/stratifyai_anthropic.py +194 -0
- stratifyai/chat/stratifyai_bedrock.py +200 -0
- stratifyai/chat/stratifyai_deepseek.py +194 -0
- stratifyai/chat/stratifyai_google.py +194 -0
- stratifyai/chat/stratifyai_grok.py +194 -0
- stratifyai/chat/stratifyai_groq.py +195 -0
- stratifyai/chat/stratifyai_ollama.py +201 -0
- stratifyai/chat/stratifyai_openai.py +209 -0
- stratifyai/chat/stratifyai_openrouter.py +201 -0
- stratifyai/chunking.py +158 -0
- stratifyai/client.py +292 -0
- stratifyai/config.py +1273 -0
- stratifyai/cost_tracker.py +257 -0
- stratifyai/embeddings.py +245 -0
- stratifyai/exceptions.py +91 -0
- stratifyai/models.py +59 -0
- stratifyai/providers/__init__.py +5 -0
- stratifyai/providers/anthropic.py +330 -0
- stratifyai/providers/base.py +183 -0
- stratifyai/providers/bedrock.py +634 -0
- stratifyai/providers/deepseek.py +39 -0
- stratifyai/providers/google.py +39 -0
- stratifyai/providers/grok.py +39 -0
- stratifyai/providers/groq.py +39 -0
- stratifyai/providers/ollama.py +43 -0
- stratifyai/providers/openai.py +344 -0
- stratifyai/providers/openai_compatible.py +372 -0
- stratifyai/providers/openrouter.py +39 -0
- stratifyai/py.typed +2 -0
- stratifyai/rag.py +381 -0
- stratifyai/retry.py +185 -0
- stratifyai/router.py +643 -0
- stratifyai/summarization.py +179 -0
- stratifyai/utils/__init__.py +11 -0
- stratifyai/utils/bedrock_validator.py +136 -0
- stratifyai/utils/code_extractor.py +327 -0
- stratifyai/utils/csv_extractor.py +197 -0
- stratifyai/utils/file_analyzer.py +192 -0
- stratifyai/utils/json_extractor.py +219 -0
- stratifyai/utils/log_extractor.py +267 -0
- stratifyai/utils/model_selector.py +324 -0
- stratifyai/utils/provider_validator.py +442 -0
- stratifyai/utils/token_counter.py +186 -0
- stratifyai/vectordb.py +344 -0
- stratifyai-0.1.0.dist-info/METADATA +263 -0
- stratifyai-0.1.0.dist-info/RECORD +57 -0
- stratifyai-0.1.0.dist-info/WHEEL +5 -0
- stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
- stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
- stratifyai-0.1.0.dist-info/top_level.txt +2 -0
stratifyai/cost_tracker.py
ADDED

@@ -0,0 +1,257 @@

"""Cost tracking module for LLM API calls."""

from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional
from collections import defaultdict


@dataclass
class CostEntry:
    """Individual cost entry for an LLM API call."""

    timestamp: datetime
    provider: str
    model: str
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    cost_usd: float
    request_id: str
    cached_tokens: int = 0
    cache_creation_tokens: int = 0
    cache_read_tokens: int = 0
    group: Optional[str] = None


class CostTracker:
    """
    Track and analyze costs across LLM API calls.

    Features:
    - Call history with detailed metrics
    - Grouping by provider, model, or custom tags
    - Cost analytics and reporting
    - Budget tracking and alerts
    """

    def __init__(self):
        """Initialize cost tracker."""
        self._entries: List[CostEntry] = []
        self._total_cost: float = 0.0
        self._budget_limit: Optional[float] = None
        self._alert_threshold: Optional[float] = None

    def add_entry(
        self,
        provider: str,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        cost_usd: float,
        request_id: str,
        cached_tokens: int = 0,
        cache_creation_tokens: int = 0,
        cache_read_tokens: int = 0,
        group: Optional[str] = None,
    ) -> None:
        """
        Add a cost entry to the tracker.

        Args:
            provider: Provider name (e.g., 'openai', 'anthropic')
            model: Model name
            prompt_tokens: Number of prompt tokens
            completion_tokens: Number of completion tokens
            total_tokens: Total tokens used
            cost_usd: Cost in USD
            request_id: Unique request identifier
            cached_tokens: Number of cached tokens
            cache_creation_tokens: Tokens written to cache
            cache_read_tokens: Tokens read from cache
            group: Optional group tag for categorization
        """
        entry = CostEntry(
            timestamp=datetime.now(),
            provider=provider,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            cost_usd=cost_usd,
            request_id=request_id,
            cached_tokens=cached_tokens,
            cache_creation_tokens=cache_creation_tokens,
            cache_read_tokens=cache_read_tokens,
            group=group,
        )
        self._entries.append(entry)
        self._total_cost += cost_usd

        # Check budget alerts
        if self._alert_threshold and self._total_cost >= self._alert_threshold:
            self._trigger_alert(self._total_cost, self._alert_threshold)

    def get_total_cost(self) -> float:
        """Get total cost across all tracked calls."""
        return self._total_cost

    def get_total_tokens(self) -> int:
        """Get total tokens across all tracked calls."""
        return sum(entry.total_tokens for entry in self._entries)

    def get_call_count(self) -> int:
        """Get total number of tracked calls."""
        return len(self._entries)

    def get_entries(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        group: Optional[str] = None,
    ) -> List[CostEntry]:
        """
        Get filtered cost entries.

        Args:
            provider: Filter by provider name
            model: Filter by model name
            group: Filter by group tag

        Returns:
            List of matching cost entries
        """
        entries = self._entries

        if provider:
            entries = [e for e in entries if e.provider == provider]
        if model:
            entries = [e for e in entries if e.model == model]
        if group:
            entries = [e for e in entries if e.group == group]

        return entries

    def get_cost_by_provider(self) -> Dict[str, float]:
        """Get total cost grouped by provider."""
        costs: Dict[str, float] = defaultdict(float)
        for entry in self._entries:
            costs[entry.provider] += entry.cost_usd
        return dict(costs)

    def get_cost_by_model(self) -> Dict[str, float]:
        """Get total cost grouped by model."""
        costs: Dict[str, float] = defaultdict(float)
        for entry in self._entries:
            costs[entry.model] += entry.cost_usd
        return dict(costs)

    def get_cost_by_group(self) -> Dict[str, float]:
        """Get total cost grouped by custom group tag."""
        costs: Dict[str, float] = defaultdict(float)
        for entry in self._entries:
            if entry.group:
                costs[entry.group] += entry.cost_usd
        return dict(costs)

    def get_tokens_by_provider(self) -> Dict[str, int]:
        """Get total tokens grouped by provider."""
        tokens: Dict[str, int] = defaultdict(int)
        for entry in self._entries:
            tokens[entry.provider] += entry.total_tokens
        return dict(tokens)

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache usage statistics."""
        total_cache_reads = sum(e.cache_read_tokens for e in self._entries)
        total_cache_creates = sum(e.cache_creation_tokens for e in self._entries)
        total_prompt_tokens = sum(e.prompt_tokens for e in self._entries)

        cache_hit_rate = 0.0
        if total_prompt_tokens > 0:
            cache_hit_rate = (total_cache_reads / total_prompt_tokens) * 100

        return {
            "total_cache_read_tokens": total_cache_reads,
            "total_cache_creation_tokens": total_cache_creates,
            "cache_hit_rate_percent": round(cache_hit_rate, 2),
        }

    def set_budget(self, limit: float, alert_threshold: Optional[float] = None) -> None:
        """
        Set budget limit and optional alert threshold.

        Args:
            limit: Maximum budget in USD
            alert_threshold: Alert when cost reaches this threshold (default: 80% of limit)
        """
        self._budget_limit = limit
        self._alert_threshold = alert_threshold or (limit * 0.8)

    def get_budget_status(self) -> Dict[str, Any]:
        """
        Get current budget status.

        Returns:
            Dictionary with budget information
        """
        if self._budget_limit is None:
            return {
                "budget_set": False,
                "total_cost": self._total_cost,
            }

        remaining = self._budget_limit - self._total_cost
        percent_used = (self._total_cost / self._budget_limit) * 100

        return {
            "budget_set": True,
            "budget_limit": self._budget_limit,
            "total_cost": self._total_cost,
            "remaining": max(0, remaining),
            "percent_used": round(percent_used, 2),
            "over_budget": self._total_cost > self._budget_limit,
            "alert_threshold": self._alert_threshold,
        }

    def is_over_budget(self) -> bool:
        """Check if current spending exceeds budget limit."""
        if self._budget_limit is None:
            return False
        return self._total_cost > self._budget_limit

    def reset(self) -> None:
        """Reset all tracked data."""
        self._entries.clear()
        self._total_cost = 0.0

    def _trigger_alert(self, current_cost: float, threshold: float) -> None:
        """
        Trigger budget alert (can be overridden for custom behavior).

        Args:
            current_cost: Current total cost
            threshold: Alert threshold that was exceeded
        """
        # Default implementation: print warning
        # Override this method for custom alert behavior (email, webhook, etc.)
        print(f"⚠️ Budget Alert: Current cost ${current_cost:.4f} exceeds threshold ${threshold:.4f}")

    def get_summary(self) -> Dict[str, Any]:
        """
        Get comprehensive summary of tracked costs.

        Returns:
            Dictionary with summary statistics
        """
        return {
            "total_cost": self._total_cost,
            "total_tokens": self.get_total_tokens(),
            "total_calls": self.get_call_count(),
            "cost_by_provider": self.get_cost_by_provider(),
            "cost_by_model": self.get_cost_by_model(),
            "tokens_by_provider": self.get_tokens_by_provider(),
            "cache_stats": self.get_cache_stats(),
            "budget_status": self.get_budget_status(),
        }
stratifyai/embeddings.py
ADDED
@@ -0,0 +1,245 @@

"""Embedding generation for RAG and semantic search.

This module provides an abstraction for generating embeddings from text using
various provider APIs (OpenAI, Cohere, etc.).
"""

import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
import os
from openai import AsyncOpenAI

from .exceptions import ProviderAPIError, AuthenticationError


@dataclass
class EmbeddingResult:
    """Result of an embedding generation request.

    Attributes:
        embeddings: List of embedding vectors (each is List[float])
        model: Name of the embedding model used
        total_tokens: Total tokens processed
        cost: Cost of the embedding request in USD
    """
    embeddings: List[List[float]]
    model: str
    total_tokens: int
    cost: float


class EmbeddingProvider(ABC):
    """Abstract base class for embedding providers.

    All embedding provider implementations must inherit from this class
    and implement the generate_embeddings method.
    """

    @abstractmethod
    async def generate_embeddings(
        self,
        texts: List[str],
        model: Optional[str] = None
    ) -> EmbeddingResult:
        """Generate embeddings for a list of texts.

        Args:
            texts: List of text strings to embed
            model: Optional model name override

        Returns:
            EmbeddingResult with embeddings and metadata

        Raises:
            ProviderAPIError: If the API request fails
            AuthenticationError: If authentication fails
        """
        pass

    @abstractmethod
    def get_embedding_dimension(self, model: str) -> int:
        """Get the dimensionality of embeddings for a given model.

        Args:
            model: Model name

        Returns:
            Embedding dimension (e.g., 1536 for text-embedding-3-small)
        """
        pass

    def generate_embeddings_sync(
        self,
        texts: List[str],
        model: Optional[str] = None
    ) -> EmbeddingResult:
        """Synchronous wrapper for generate_embeddings."""
        return asyncio.run(self.generate_embeddings(texts, model))


class OpenAIEmbeddingProvider(EmbeddingProvider):
    """OpenAI embedding provider implementation.

    Supports:
    - text-embedding-3-small (1536 dimensions)
    - text-embedding-3-large (3072 dimensions)
    - text-embedding-ada-002 (1536 dimensions, legacy)
    """

    # Embedding costs per 1M tokens (as of Feb 2026)
    EMBEDDING_COSTS = {
        "text-embedding-3-small": 0.020 / 1_000_000,  # $0.020 per 1M tokens
        "text-embedding-3-large": 0.130 / 1_000_000,  # $0.130 per 1M tokens
        "text-embedding-ada-002": 0.100 / 1_000_000,  # $0.100 per 1M tokens (legacy)
    }

    # Embedding dimensions by model
    EMBEDDING_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    # Default model
    DEFAULT_MODEL = "text-embedding-3-small"

    def __init__(self, api_key: Optional[str] = None):
        """Initialize OpenAI embedding provider.

        Args:
            api_key: OpenAI API key. If None, reads from OPENAI_API_KEY env var.

        Raises:
            AuthenticationError: If no API key is provided or found
        """
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            # AuthenticationError takes the provider name and builds its own
            # "check API key" message; set OPENAI_API_KEY or pass api_key.
            raise AuthenticationError("openai")

        self.client = AsyncOpenAI(api_key=self.api_key)

    async def generate_embeddings(
        self,
        texts: List[str],
        model: Optional[str] = None
    ) -> EmbeddingResult:
        """Generate embeddings for a list of texts using OpenAI.

        Args:
            texts: List of text strings to embed
            model: Model name (default: text-embedding-3-small)

        Returns:
            EmbeddingResult with embeddings and metadata

        Raises:
            ProviderAPIError: If the API request fails
            AuthenticationError: If authentication fails
        """
        model = model or self.DEFAULT_MODEL

        if model not in self.EMBEDDING_COSTS:
            raise ValueError(
                f"Unknown OpenAI embedding model: {model}. "
                f"Supported models: {list(self.EMBEDDING_COSTS.keys())}"
            )

        if not texts:
            return EmbeddingResult(
                embeddings=[],
                model=model,
                total_tokens=0,
                cost=0.0
            )

        try:
            # Call OpenAI API
            response = await self.client.embeddings.create(
                input=texts,
                model=model
            )

            # Extract embeddings
            embeddings = [data.embedding for data in response.data]

            # Calculate cost
            total_tokens = response.usage.total_tokens
            cost = total_tokens * self.EMBEDDING_COSTS[model]

            return EmbeddingResult(
                embeddings=embeddings,
                model=model,
                total_tokens=total_tokens,
                cost=cost
            )

        except Exception as e:
            error_msg = str(e)
            if "authentication" in error_msg.lower() or "api key" in error_msg.lower():
                raise AuthenticationError("openai") from e
            else:
                raise ProviderAPIError(
                    f"Embedding request failed: {error_msg}", provider="openai"
                ) from e

    async def generate_embedding(self, text: str, model: Optional[str] = None) -> List[float]:
        """Generate embedding for a single text string.

        Convenience method for single text embedding.

        Args:
            text: Text string to embed
            model: Model name (default: text-embedding-3-small)

        Returns:
            Embedding vector as List[float]
        """
        result = await self.generate_embeddings([text], model=model)
        return result.embeddings[0]

    def get_embedding_dimension(self, model: str) -> int:
        """Get the dimensionality of embeddings for a given model.

        Args:
            model: Model name

        Returns:
            Embedding dimension

        Raises:
            ValueError: If model is unknown
        """
        if model not in self.EMBEDDING_DIMENSIONS:
            raise ValueError(
                f"Unknown OpenAI embedding model: {model}. "
                f"Supported models: {list(self.EMBEDDING_DIMENSIONS.keys())}"
            )
        return self.EMBEDDING_DIMENSIONS[model]


def create_embedding_provider(
    provider: str = "openai",
    api_key: Optional[str] = None
) -> EmbeddingProvider:
    """Factory function to create embedding providers.

    Args:
        provider: Provider name (currently only "openai" supported)
        api_key: API key for the provider

    Returns:
        EmbeddingProvider instance

    Raises:
        ValueError: If provider is unknown
    """
    if provider.lower() == "openai":
        return OpenAIEmbeddingProvider(api_key=api_key)
    else:
        raise ValueError(
            f"Unknown embedding provider: {provider}. "
            "Currently supported: openai"
        )
stratifyai/exceptions.py
ADDED
@@ -0,0 +1,91 @@

"""Custom exceptions for LLM abstraction layer."""

from typing import Optional


class LLMAbstractionError(Exception):
    """Base exception for all LLM abstraction errors."""
    pass


class ProviderError(LLMAbstractionError):
    """Base exception for provider-specific errors."""
    pass


class InvalidProviderError(ProviderError):
    """Raised when an invalid provider is specified."""
    pass


class ProviderAPIError(ProviderError):
    """Raised when a provider API call fails."""

    def __init__(self, message: str, provider: str, status_code: Optional[int] = None):
        self.provider = provider
        self.status_code = status_code
        super().__init__(f"[{provider}] {message}")


class AuthenticationError(ProviderError):
    """Raised when API key authentication fails."""

    def __init__(self, provider: str):
        self.provider = provider
        super().__init__(f"Authentication failed for {provider}. Check API key.")


class InsufficientBalanceError(ProviderError):
    """Raised when provider account has insufficient balance."""

    def __init__(self, provider: str):
        self.provider = provider
        super().__init__(f"Insufficient balance in {provider} account. Please add credits.")


class RateLimitError(ProviderError):
    """Raised when rate limit is exceeded."""

    def __init__(self, provider: str, retry_after: Optional[int] = None):
        self.provider = provider
        self.retry_after = retry_after
        message = f"Rate limit exceeded for {provider}"
        if retry_after:
            message += f". Retry after {retry_after} seconds"
        super().__init__(message)


class InvalidModelError(ProviderError):
    """Raised when an invalid model is specified for a provider."""

    def __init__(self, model: str, provider: str):
        self.model = model
        self.provider = provider
        super().__init__(f"Model '{model}' not supported by {provider}")


class BudgetExceededError(LLMAbstractionError):
    """Raised when budget limit is exceeded."""

    def __init__(self, current_cost: float, budget_limit: float):
        self.current_cost = current_cost
        self.budget_limit = budget_limit
        super().__init__(
            f"Budget limit ${budget_limit:.2f} exceeded. "
            f"Current spend: ${current_cost:.2f}"
        )


class MaxRetriesExceededError(LLMAbstractionError):
    """Raised when maximum retry attempts are exceeded."""

    def __init__(self, attempts: int, last_error: Exception):
        self.attempts = attempts
        self.last_error = last_error
        super().__init__(
            f"Maximum retry attempts ({attempts}) exceeded. "
            f"Last error: {str(last_error)}"
        )


class ValidationError(LLMAbstractionError):
    """Raised when input validation fails."""
    pass
stratifyai/models.py
ADDED
@@ -0,0 +1,59 @@

"""Data models for unified LLM abstraction layer."""

from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Literal, Optional


@dataclass
class Message:
    """Standard message format for all providers (OpenAI-compatible)."""
    role: Literal["system", "user", "assistant"]
    content: str
    name: Optional[str] = None  # For multi-agent scenarios
    cache_control: Optional[dict] = None  # For providers that support prompt caching (Anthropic, OpenAI)


@dataclass
class Usage:
    """Token usage and cost information."""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    cached_tokens: int = 0  # Tokens retrieved from cache
    cache_creation_tokens: int = 0  # Tokens written to cache (Anthropic)
    cache_read_tokens: int = 0  # Tokens read from cache (Anthropic)
    reasoning_tokens: int = 0  # For reasoning models like o1/o3
    cost_usd: float = 0.0
    cost_breakdown: Optional[dict] = None  # Detailed cost breakdown by token type


@dataclass
class ChatRequest:
    """Unified request structure for chat completions."""
    model: str
    messages: List[Message]
    temperature: float = 0.7
    max_tokens: Optional[int] = None
    stream: bool = False
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    stop: Optional[List[str]] = None
    # Provider-specific extensions
    reasoning_effort: Optional[str] = None  # OpenAI o1/o3
    extra_params: dict = field(default_factory=dict)


@dataclass
class ChatResponse:
    """Standard response from any provider."""
    id: str
    model: str
    content: str
    finish_reason: str
    usage: Usage
    provider: str
    created_at: datetime
    raw_response: dict  # Original provider response for debugging
    latency_ms: Optional[float] = None  # Response latency in milliseconds