stratifyai 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +5 -0
- cli/stratifyai_cli.py +1753 -0
- stratifyai/__init__.py +113 -0
- stratifyai/api_key_helper.py +372 -0
- stratifyai/caching.py +279 -0
- stratifyai/chat/__init__.py +54 -0
- stratifyai/chat/builder.py +366 -0
- stratifyai/chat/stratifyai_anthropic.py +194 -0
- stratifyai/chat/stratifyai_bedrock.py +200 -0
- stratifyai/chat/stratifyai_deepseek.py +194 -0
- stratifyai/chat/stratifyai_google.py +194 -0
- stratifyai/chat/stratifyai_grok.py +194 -0
- stratifyai/chat/stratifyai_groq.py +195 -0
- stratifyai/chat/stratifyai_ollama.py +201 -0
- stratifyai/chat/stratifyai_openai.py +209 -0
- stratifyai/chat/stratifyai_openrouter.py +201 -0
- stratifyai/chunking.py +158 -0
- stratifyai/client.py +292 -0
- stratifyai/config.py +1273 -0
- stratifyai/cost_tracker.py +257 -0
- stratifyai/embeddings.py +245 -0
- stratifyai/exceptions.py +91 -0
- stratifyai/models.py +59 -0
- stratifyai/providers/__init__.py +5 -0
- stratifyai/providers/anthropic.py +330 -0
- stratifyai/providers/base.py +183 -0
- stratifyai/providers/bedrock.py +634 -0
- stratifyai/providers/deepseek.py +39 -0
- stratifyai/providers/google.py +39 -0
- stratifyai/providers/grok.py +39 -0
- stratifyai/providers/groq.py +39 -0
- stratifyai/providers/ollama.py +43 -0
- stratifyai/providers/openai.py +344 -0
- stratifyai/providers/openai_compatible.py +372 -0
- stratifyai/providers/openrouter.py +39 -0
- stratifyai/py.typed +2 -0
- stratifyai/rag.py +381 -0
- stratifyai/retry.py +185 -0
- stratifyai/router.py +643 -0
- stratifyai/summarization.py +179 -0
- stratifyai/utils/__init__.py +11 -0
- stratifyai/utils/bedrock_validator.py +136 -0
- stratifyai/utils/code_extractor.py +327 -0
- stratifyai/utils/csv_extractor.py +197 -0
- stratifyai/utils/file_analyzer.py +192 -0
- stratifyai/utils/json_extractor.py +219 -0
- stratifyai/utils/log_extractor.py +267 -0
- stratifyai/utils/model_selector.py +324 -0
- stratifyai/utils/provider_validator.py +442 -0
- stratifyai/utils/token_counter.py +186 -0
- stratifyai/vectordb.py +344 -0
- stratifyai-0.1.0.dist-info/METADATA +263 -0
- stratifyai-0.1.0.dist-info/RECORD +57 -0
- stratifyai-0.1.0.dist-info/WHEEL +5 -0
- stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
- stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
- stratifyai-0.1.0.dist-info/top_level.txt +2 -0
stratifyai/providers/google.py
@@ -0,0 +1,39 @@
+"""Google Gemini provider implementation."""
+
+import os
+from typing import Optional
+
+from ..config import GOOGLE_MODELS, PROVIDER_BASE_URLS
+from ..exceptions import AuthenticationError
+from .openai_compatible import OpenAICompatibleProvider
+
+
+class GoogleProvider(OpenAICompatibleProvider):
+    """Google Gemini provider using OpenAI-compatible API."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        config: dict = None
+    ):
+        """
+        Initialize Google Gemini provider.
+
+        Args:
+            api_key: Google API key (defaults to GOOGLE_API_KEY env var)
+            config: Optional provider-specific configuration
+
+        Raises:
+            AuthenticationError: If API key not provided
+        """
+        api_key = api_key or os.getenv("GOOGLE_API_KEY")
+        if not api_key:
+            raise AuthenticationError("google")
+
+        base_url = PROVIDER_BASE_URLS["google"]
+        super().__init__(api_key, base_url, GOOGLE_MODELS, config)
+
+    @property
+    def provider_name(self) -> str:
+        """Return provider name."""
+        return "google"
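
For orientation, here is a minimal usage sketch for this thin OpenAI-compatible wrapper; the Grok and Groq providers in the next two hunks follow the identical pattern. It assumes the wheel is installed as `stratifyai` and that `GOOGLE_API_KEY` is set; it is not taken from the package's own documentation.

```python
# Hypothetical sketch, not from the package docs: instantiate the provider
# and let it fall back to the GOOGLE_API_KEY environment variable.
from stratifyai.providers.google import GoogleProvider
from stratifyai.exceptions import AuthenticationError

try:
    provider = GoogleProvider()    # uses os.getenv("GOOGLE_API_KEY") when no key is passed
    print(provider.provider_name)  # -> "google"
except AuthenticationError:
    print("GOOGLE_API_KEY is not set")
```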
stratifyai/providers/grok.py
@@ -0,0 +1,39 @@
+"""Grok (X.AI) provider implementation."""
+
+import os
+from typing import Optional
+
+from ..config import GROK_MODELS, PROVIDER_BASE_URLS
+from ..exceptions import AuthenticationError
+from .openai_compatible import OpenAICompatibleProvider
+
+
+class GrokProvider(OpenAICompatibleProvider):
+    """Grok (X.AI) provider using OpenAI-compatible API."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        config: dict = None
+    ):
+        """
+        Initialize Grok provider.
+
+        Args:
+            api_key: Grok API key (defaults to GROK_API_KEY env var)
+            config: Optional provider-specific configuration
+
+        Raises:
+            AuthenticationError: If API key not provided
+        """
+        api_key = api_key or os.getenv("GROK_API_KEY")
+        if not api_key:
+            raise AuthenticationError("grok")
+
+        base_url = PROVIDER_BASE_URLS["grok"]
+        super().__init__(api_key, base_url, GROK_MODELS, config)
+
+    @property
+    def provider_name(self) -> str:
+        """Return provider name."""
+        return "grok"
stratifyai/providers/groq.py
@@ -0,0 +1,39 @@
+"""Groq provider implementation."""
+
+import os
+from typing import Optional
+
+from ..config import GROQ_MODELS, PROVIDER_BASE_URLS
+from ..exceptions import AuthenticationError
+from .openai_compatible import OpenAICompatibleProvider
+
+
+class GroqProvider(OpenAICompatibleProvider):
+    """Groq provider using OpenAI-compatible API."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        config: dict = None
+    ):
+        """
+        Initialize Groq provider.
+
+        Args:
+            api_key: Groq API key (defaults to GROQ_API_KEY env var)
+            config: Optional provider-specific configuration
+
+        Raises:
+            AuthenticationError: If API key not provided
+        """
+        api_key = api_key or os.getenv("GROQ_API_KEY")
+        if not api_key:
+            raise AuthenticationError("groq")
+
+        base_url = PROVIDER_BASE_URLS["groq"]
+        super().__init__(api_key, base_url, GROQ_MODELS, config)
+
+    @property
+    def provider_name(self) -> str:
+        """Return provider name."""
+        return "groq"
stratifyai/providers/ollama.py
@@ -0,0 +1,43 @@
+"""Ollama provider implementation for local models."""
+
+import os
+from typing import Optional
+
+from ..config import OLLAMA_MODELS, PROVIDER_BASE_URLS
+from .openai_compatible import OpenAICompatibleProvider
+
+
+class OllamaProvider(OpenAICompatibleProvider):
+    """Ollama provider for local models using OpenAI-compatible API."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        config: dict = None
+    ):
+        """
+        Initialize Ollama provider.
+
+        Args:
+            api_key: Optional API key (Ollama typically doesn't require one)
+            config: Optional provider-specific configuration (can include base_url)
+
+        Note:
+            Ollama runs locally and typically doesn't require an API key.
+            Default base URL is http://localhost:11434/v1
+        """
+        # Ollama doesn't require an API key, use placeholder
+        from ..api_key_helper import APIKeyHelper
+        api_key = APIKeyHelper.get_api_key("ollama", api_key) or "ollama"
+
+        # Allow custom base URL via config
+        base_url = PROVIDER_BASE_URLS["ollama"]
+        if config and "base_url" in config:
+            base_url = config["base_url"]
+
+        super().__init__(api_key, base_url, OLLAMA_MODELS, config)
+
+    @property
+    def provider_name(self) -> str:
+        """Return provider name."""
+        return "ollama"
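
A similar hedged sketch for the local Ollama case, showing the `base_url` override that the constructor above reads from `config`; the URL is just the default mentioned in the docstring, and the snippet is not from the package docs.

```python
# Hypothetical sketch: point the Ollama provider at an explicit host.
# No API key is needed; the constructor substitutes an "ollama" placeholder.
from stratifyai.providers.ollama import OllamaProvider

provider = OllamaProvider(config={"base_url": "http://localhost:11434/v1"})
print(provider.provider_name)  # -> "ollama"
```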
stratifyai/providers/openai.py
@@ -0,0 +1,344 @@
+"""OpenAI provider implementation."""
+
+import os
+from datetime import datetime
+from typing import AsyncIterator, List, Optional
+
+from openai import AsyncOpenAI
+
+from ..config import OPENAI_MODELS, PROVIDER_CONSTRAINTS
+from ..exceptions import AuthenticationError, InvalidModelError, ProviderAPIError
+from ..models import ChatRequest, ChatResponse, Usage
+from .base import BaseProvider
+
+
+class OpenAIProvider(BaseProvider):
+    """OpenAI provider implementation with cost tracking."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        config: dict = None
+    ):
+        """
+        Initialize OpenAI provider.
+
+        Args:
+            api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
+            config: Optional provider-specific configuration
+
+        Raises:
+            ValueError: If API key not provided (with helpful setup instructions)
+        """
+        from ..api_key_helper import get_api_key_or_error
+        api_key = get_api_key_or_error("openai", api_key)
+        super().__init__(api_key, config)
+        self._initialize_client()
+
+    def _initialize_client(self) -> None:
+        """Initialize OpenAI async client."""
+        try:
+            self._client = AsyncOpenAI(api_key=self.api_key)
+        except Exception as e:
+            raise ProviderAPIError(
+                f"Failed to initialize OpenAI client: {str(e)}",
+                "openai"
+            )
+
+    @property
+    def provider_name(self) -> str:
+        """Return provider name."""
+        return "openai"
+
+    def get_supported_models(self) -> List[str]:
+        """Return list of supported OpenAI models."""
+        return list(OPENAI_MODELS.keys())
+
+    def supports_caching(self, model: str) -> bool:
+        """Check if model supports prompt caching."""
+        model_info = OPENAI_MODELS.get(model, {})
+        return model_info.get("supports_caching", False)
+
+    async def chat_completion(self, request: ChatRequest) -> ChatResponse:
+        """
+        Execute chat completion request.
+
+        Args:
+            request: Unified chat request
+
+        Returns:
+            Unified chat response with cost tracking
+
+        Raises:
+            InvalidModelError: If model not supported
+            ProviderAPIError: If API call fails
+        """
+        if not self.validate_model(request.model):
+            raise InvalidModelError(request.model, self.provider_name)
+
+        # Validate temperature constraints for OpenAI (0.0 to 2.0)
+        constraints = PROVIDER_CONSTRAINTS.get(self.provider_name, {})
+        self.validate_temperature(
+            request.temperature,
+            constraints.get("min_temperature", 0.0),
+            constraints.get("max_temperature", 2.0)
+        )
+
+        # Build OpenAI-specific request parameters
+        messages = []
+        for msg in request.messages:
+            message_dict = {"role": msg.role, "content": msg.content}
+            # Add cache_control if present and model supports caching
+            if msg.cache_control and self.supports_caching(request.model):
+                message_dict["cache_control"] = msg.cache_control
+            messages.append(message_dict)
+
+        openai_params = {
+            "model": request.model,
+            "messages": messages,
+        }
+
+        # Check if model is a reasoning model (o-series)
+        model_info = OPENAI_MODELS.get(request.model, {})
+        is_reasoning_model = model_info.get("reasoning_model", False)
+
+        # Also check if model name starts with o1, o3, gpt-5, or just 'o' followed by a digit
+        # This catches variants like o1-preview, o1-2024-12-17, o3-mini, gpt-5, etc.
+        if not is_reasoning_model and request.model:
+            model_lower = request.model.lower()
+            # Match: o1*, o3*, gpt-5*, "reasoning", or o followed by digit
+            is_reasoning_model = (
+                model_lower.startswith("o1") or
+                model_lower.startswith("o3") or
+                model_lower.startswith("gpt-5") or
+                "reasoning" in model_lower or
+                (model_lower.startswith("o") and len(model_lower) > 1 and model_lower[1].isdigit())
+            )
+
+        # Only add these parameters for non-reasoning models
+        # Reasoning models like o1, o1-mini, o3-mini don't support temperature/top_p/penalties
+        if not is_reasoning_model:
+            openai_params["temperature"] = request.temperature
+            openai_params["top_p"] = request.top_p
+            openai_params["frequency_penalty"] = request.frequency_penalty
+            openai_params["presence_penalty"] = request.presence_penalty
+
+        # Add optional parameters
+        if request.max_tokens:
+            openai_params["max_tokens"] = request.max_tokens
+        if request.stop:
+            openai_params["stop"] = request.stop
+
+        # Add reasoning_effort for o-series models
+        if request.reasoning_effort and "o" in request.model:
+            openai_params["reasoning_effort"] = request.reasoning_effort
+
+        # Add any extra params
+        if request.extra_params:
+            openai_params.update(request.extra_params)
+
+        try:
+            # Make API request
+            raw_response = await self._client.chat.completions.create(**openai_params)
+            # Normalize and return
+            return self._normalize_response(raw_response.model_dump())
+        except Exception as e:
+            raise ProviderAPIError(
+                f"Chat completion failed: {str(e)}",
+                self.provider_name
+            )
+
+    async def chat_completion_stream(
+        self, request: ChatRequest
+    ) -> AsyncIterator[ChatResponse]:
+        """
+        Execute streaming chat completion request.
+
+        Args:
+            request: Unified chat request
+
+        Yields:
+            Unified chat response chunks
+
+        Raises:
+            InvalidModelError: If model not supported
+            ProviderAPIError: If API call fails
+        """
+        if not self.validate_model(request.model):
+            raise InvalidModelError(request.model, self.provider_name)
+
+        # Build request parameters
+        openai_params = {
+            "model": request.model,
+            "messages": [
+                {"role": msg.role, "content": msg.content}
+                for msg in request.messages
+            ],
+            "stream": True,
+        }
+
+        # Check if model is a reasoning model
+        model_info = OPENAI_MODELS.get(request.model, {})
+        is_reasoning_model = model_info.get("reasoning_model", False)
+
+        # Also check if model name starts with o1, o3, gpt-5, or just 'o' followed by a digit
+        if not is_reasoning_model and request.model:
+            model_lower = request.model.lower()
+            is_reasoning_model = (
+                model_lower.startswith("o1") or
+                model_lower.startswith("o3") or
+                model_lower.startswith("gpt-5") or
+                "reasoning" in model_lower or
+                (model_lower.startswith("o") and len(model_lower) > 1 and model_lower[1].isdigit())
+            )
+
+        # Only add temperature for non-reasoning models
+        if not is_reasoning_model:
+            openai_params["temperature"] = request.temperature
+
+        if request.max_tokens:
+            openai_params["max_tokens"] = request.max_tokens
+
+        try:
+            stream = await self._client.chat.completions.create(**openai_params)
+
+            async for chunk in stream:
+                chunk_dict = chunk.model_dump()
+                if chunk.choices and chunk.choices[0].delta.content:
+                    yield self._normalize_stream_chunk(chunk_dict)
+        except Exception as e:
+            raise ProviderAPIError(
+                f"Streaming chat completion failed: {str(e)}",
+                self.provider_name
+            )
+
+    def _normalize_response(self, raw_response: dict) -> ChatResponse:
+        """
+        Convert OpenAI response to unified format.
+
+        Args:
+            raw_response: Raw OpenAI API response
+
+        Returns:
+            Normalized ChatResponse with cost
+        """
+        choice = raw_response["choices"][0]
+        usage_dict = raw_response.get("usage", {})
+
+        # Extract token usage
+        prompt_details = usage_dict.get("prompt_tokens_details", {})
+        usage = Usage(
+            prompt_tokens=usage_dict.get("prompt_tokens", 0),
+            completion_tokens=usage_dict.get("completion_tokens", 0),
+            total_tokens=usage_dict.get("total_tokens", 0),
+            cached_tokens=prompt_details.get("cached_tokens", 0),
+            cache_creation_tokens=prompt_details.get("cache_creation_input_tokens", 0),
+            cache_read_tokens=prompt_details.get("cached_tokens", 0),
+            reasoning_tokens=usage_dict.get("completion_tokens_details", {}).get(
+                "reasoning_tokens", 0
+            ),
+        )
+
+        # Calculate cost including cache costs
+        base_cost = self._calculate_cost(usage, raw_response["model"])
+        cache_cost = self._calculate_cache_cost(
+            usage.cache_creation_tokens,
+            usage.cache_read_tokens,
+            raw_response["model"]
+        )
+        usage.cost_usd = base_cost + cache_cost
+
+        # Add cost breakdown
+        if usage.cache_creation_tokens > 0 or usage.cache_read_tokens > 0:
+            usage.cost_breakdown = {
+                "base_cost": base_cost,
+                "cache_cost": cache_cost,
+                "total_cost": usage.cost_usd,
+            }
+
+        return ChatResponse(
+            id=raw_response["id"],
+            model=raw_response["model"],
+            content=choice["message"]["content"] or "",
+            finish_reason=choice["finish_reason"],
+            usage=usage,
+            provider=self.provider_name,
+            created_at=datetime.fromtimestamp(raw_response["created"]),
+            raw_response=raw_response,
+        )
+
+    def _normalize_stream_chunk(self, chunk_dict: dict) -> ChatResponse:
+        """Normalize streaming chunk to ChatResponse format."""
+        choice = chunk_dict["choices"][0]
+        content = choice["delta"].get("content", "")
+
+        return ChatResponse(
+            id=chunk_dict["id"],
+            model=chunk_dict["model"],
+            content=content,
+            finish_reason=choice.get("finish_reason", ""),
+            usage=Usage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0
+            ),
+            provider=self.provider_name,
+            created_at=datetime.fromtimestamp(chunk_dict["created"]),
+            raw_response=chunk_dict,
+        )
+
+    def _calculate_cost(self, usage: Usage, model: str) -> float:
+        """
+        Calculate cost in USD based on token usage (excluding cache costs).
+
+        Args:
+            usage: Token usage information
+            model: Model name used
+
+        Returns:
+            Cost in USD
+        """
+        model_info = OPENAI_MODELS.get(model, {})
+        cost_input = model_info.get("cost_input", 0.0)
+        cost_output = model_info.get("cost_output", 0.0)
+
+        # Calculate non-cached prompt tokens
+        non_cached_prompt_tokens = usage.prompt_tokens - usage.cache_read_tokens
+
+        # Costs are per 1M tokens
+        input_cost = (non_cached_prompt_tokens / 1_000_000) * cost_input
+        output_cost = (usage.completion_tokens / 1_000_000) * cost_output
+
+        return input_cost + output_cost
+
+    def _calculate_cache_cost(
+        self,
+        cache_creation_tokens: int,
+        cache_read_tokens: int,
+        model: str
+    ) -> float:
+        """
+        Calculate cost for cached tokens.
+
+        Args:
+            cache_creation_tokens: Number of tokens written to cache
+            cache_read_tokens: Number of tokens read from cache
+            model: Model name used
+
+        Returns:
+            Cost in USD for cache operations
+        """
+        model_info = OPENAI_MODELS.get(model, {})
+
+        # Check if model supports caching
+        if not model_info.get("supports_caching", False):
+            return 0.0
+
+        cost_cache_write = model_info.get("cost_cache_write", 0.0)
+        cost_cache_read = model_info.get("cost_cache_read", 0.0)
+
+        # Costs are per 1M tokens
+        write_cost = (cache_creation_tokens / 1_000_000) * cost_cache_write
+        read_cost = (cache_read_tokens / 1_000_000) * cost_cache_read
+
+        return write_cost + read_cost
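
To make the pricing arithmetic in `_calculate_cost` and `_calculate_cache_cost` concrete, here is a standalone rework of the same per-1M-token formula. The prices and token counts are made-up placeholders, not values from `OPENAI_MODELS`.

```python
# Illustration of the cost formula used above; all numbers are placeholders.
cost_input = 2.50       # USD per 1M non-cached prompt tokens (placeholder)
cost_output = 10.00     # USD per 1M completion tokens (placeholder)
cost_cache_read = 1.25  # USD per 1M cached prompt tokens (placeholder)

prompt_tokens = 12_000
cache_read_tokens = 8_000
completion_tokens = 1_500

non_cached = prompt_tokens - cache_read_tokens              # 4_000
base = (non_cached / 1_000_000) * cost_input \
     + (completion_tokens / 1_000_000) * cost_output        # 0.010 + 0.015
cache = (cache_read_tokens / 1_000_000) * cost_cache_read   # 0.010
print(f"total: ${base + cache:.4f}")                        # total: $0.0350
```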