stratifyai-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +5 -0
- cli/stratifyai_cli.py +1753 -0
- stratifyai/__init__.py +113 -0
- stratifyai/api_key_helper.py +372 -0
- stratifyai/caching.py +279 -0
- stratifyai/chat/__init__.py +54 -0
- stratifyai/chat/builder.py +366 -0
- stratifyai/chat/stratifyai_anthropic.py +194 -0
- stratifyai/chat/stratifyai_bedrock.py +200 -0
- stratifyai/chat/stratifyai_deepseek.py +194 -0
- stratifyai/chat/stratifyai_google.py +194 -0
- stratifyai/chat/stratifyai_grok.py +194 -0
- stratifyai/chat/stratifyai_groq.py +195 -0
- stratifyai/chat/stratifyai_ollama.py +201 -0
- stratifyai/chat/stratifyai_openai.py +209 -0
- stratifyai/chat/stratifyai_openrouter.py +201 -0
- stratifyai/chunking.py +158 -0
- stratifyai/client.py +292 -0
- stratifyai/config.py +1273 -0
- stratifyai/cost_tracker.py +257 -0
- stratifyai/embeddings.py +245 -0
- stratifyai/exceptions.py +91 -0
- stratifyai/models.py +59 -0
- stratifyai/providers/__init__.py +5 -0
- stratifyai/providers/anthropic.py +330 -0
- stratifyai/providers/base.py +183 -0
- stratifyai/providers/bedrock.py +634 -0
- stratifyai/providers/deepseek.py +39 -0
- stratifyai/providers/google.py +39 -0
- stratifyai/providers/grok.py +39 -0
- stratifyai/providers/groq.py +39 -0
- stratifyai/providers/ollama.py +43 -0
- stratifyai/providers/openai.py +344 -0
- stratifyai/providers/openai_compatible.py +372 -0
- stratifyai/providers/openrouter.py +39 -0
- stratifyai/py.typed +2 -0
- stratifyai/rag.py +381 -0
- stratifyai/retry.py +185 -0
- stratifyai/router.py +643 -0
- stratifyai/summarization.py +179 -0
- stratifyai/utils/__init__.py +11 -0
- stratifyai/utils/bedrock_validator.py +136 -0
- stratifyai/utils/code_extractor.py +327 -0
- stratifyai/utils/csv_extractor.py +197 -0
- stratifyai/utils/file_analyzer.py +192 -0
- stratifyai/utils/json_extractor.py +219 -0
- stratifyai/utils/log_extractor.py +267 -0
- stratifyai/utils/model_selector.py +324 -0
- stratifyai/utils/provider_validator.py +442 -0
- stratifyai/utils/token_counter.py +186 -0
- stratifyai/vectordb.py +344 -0
- stratifyai-0.1.0.dist-info/METADATA +263 -0
- stratifyai-0.1.0.dist-info/RECORD +57 -0
- stratifyai-0.1.0.dist-info/WHEEL +5 -0
- stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
- stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
- stratifyai-0.1.0.dist-info/top_level.txt +2 -0

stratifyai/providers/anthropic.py
@@ -0,0 +1,330 @@
"""Anthropic provider implementation."""

import os
from datetime import datetime
from typing import AsyncIterator, List, Optional

from anthropic import AsyncAnthropic

from ..config import ANTHROPIC_MODELS, PROVIDER_CONSTRAINTS
from ..exceptions import AuthenticationError, InvalidModelError, ProviderAPIError
from ..models import ChatRequest, ChatResponse, Usage
from .base import BaseProvider


class AnthropicProvider(BaseProvider):
    """Anthropic provider implementation with Messages API."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        config: dict = None
    ):
        """
        Initialize Anthropic provider.

        Args:
            api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
            config: Optional provider-specific configuration

        Raises:
            AuthenticationError: If API key not provided
        """
        api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise AuthenticationError("anthropic")
        super().__init__(api_key, config)
        self._initialize_client()

    def _initialize_client(self) -> None:
        """Initialize Anthropic async client."""
        try:
            self._client = AsyncAnthropic(api_key=self.api_key)
        except Exception as e:
            raise ProviderAPIError(
                f"Failed to initialize Anthropic client: {str(e)}",
                "anthropic"
            )

    @property
    def provider_name(self) -> str:
        """Return provider name."""
        return "anthropic"

    def get_supported_models(self) -> List[str]:
        """Return list of supported Anthropic models."""
        return list(ANTHROPIC_MODELS.keys())

    def supports_caching(self, model: str) -> bool:
        """Check if model supports prompt caching."""
        model_info = ANTHROPIC_MODELS.get(model, {})
        return model_info.get("supports_caching", False)

    async def chat_completion(self, request: ChatRequest) -> ChatResponse:
        """
        Execute chat completion request using Messages API.

        Args:
            request: Unified chat request

        Returns:
            Unified chat response with cost tracking

        Raises:
            InvalidModelError: If model not supported
            ProviderAPIError: If API call fails
        """
        if not self.validate_model(request.model):
            raise InvalidModelError(request.model, self.provider_name)

        # Validate temperature constraints for Anthropic (0.0 to 1.0)
        constraints = PROVIDER_CONSTRAINTS.get(self.provider_name, {})
        self.validate_temperature(
            request.temperature,
            constraints.get("min_temperature", 0.0),
            constraints.get("max_temperature", 1.0)
        )

        # Convert messages to Anthropic format
        # Anthropic requires system message separate from messages array
        system_message = None
        messages = []

        for msg in request.messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                message_dict = {"role": msg.role, "content": msg.content}
                # Add cache_control if present and model supports caching
                if msg.cache_control and self.supports_caching(request.model):
                    message_dict["cache_control"] = msg.cache_control
                messages.append(message_dict)

        # Build Anthropic-specific request parameters
        anthropic_params = {
            "model": request.model,
            "messages": messages,
            "max_tokens": request.max_tokens or 4096,  # Anthropic requires max_tokens
        }

        # Anthropic only allows one of temperature or top_p
        # Prefer temperature if it's not the default, otherwise use top_p if it's not default
        # Default temperature is 0.7, default top_p is 1.0
        if request.temperature != 0.7:
            # Temperature was explicitly set, use it
            anthropic_params["temperature"] = request.temperature
        elif request.top_p != 1.0:
            # top_p was explicitly set (not default), use it
            anthropic_params["top_p"] = request.top_p
        else:
            # Both are defaults, use temperature
            anthropic_params["temperature"] = request.temperature

        # Add system message if present
        if system_message:
            anthropic_params["system"] = system_message

        # Add optional parameters
        if request.stop:
            anthropic_params["stop_sequences"] = request.stop

        # Add any extra params
        if request.extra_params:
            anthropic_params.update(request.extra_params)

        try:
            # Make API request
            raw_response = await self._client.messages.create(**anthropic_params)
            # Normalize and return
            return self._normalize_response(raw_response.model_dump())
        except Exception as e:
            raise ProviderAPIError(
                f"Chat completion failed: {str(e)}",
                self.provider_name
            )

    async def chat_completion_stream(
        self, request: ChatRequest
    ) -> AsyncIterator[ChatResponse]:
        """
        Execute streaming chat completion request.

        Args:
            request: Unified chat request

        Yields:
            Unified chat response chunks

        Raises:
            InvalidModelError: If model not supported
            ProviderAPIError: If API call fails
        """
        if not self.validate_model(request.model):
            raise InvalidModelError(request.model, self.provider_name)

        # Validate temperature constraints for Anthropic (0.0 to 1.0)
        constraints = PROVIDER_CONSTRAINTS.get(self.provider_name, {})
        self.validate_temperature(
            request.temperature,
            constraints.get("min_temperature", 0.0),
            constraints.get("max_temperature", 1.0)
        )

        # Convert messages to Anthropic format
        system_message = None
        messages = []

        for msg in request.messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                messages.append({"role": msg.role, "content": msg.content})

        # Build request parameters
        anthropic_params = {
            "model": request.model,
            "messages": messages,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens or 4096,
        }

        if system_message:
            anthropic_params["system"] = system_message

        try:
            async with self._client.messages.stream(**anthropic_params) as stream:
                async for chunk in stream.text_stream:
                    yield self._normalize_stream_chunk(chunk)
        except Exception as e:
            raise ProviderAPIError(
                f"Streaming chat completion failed: {str(e)}",
                self.provider_name
            )

    def _normalize_response(self, raw_response: dict) -> ChatResponse:
        """
        Convert Anthropic response to unified format.

        Args:
            raw_response: Raw Anthropic API response

        Returns:
            Normalized ChatResponse with cost
        """
        # Extract content from response
        content = ""
        if raw_response.get("content"):
            for block in raw_response["content"]:
                if block.get("type") == "text":
                    content += block.get("text", "")

        # Extract token usage
        usage_dict = raw_response.get("usage", {})
        usage = Usage(
            prompt_tokens=usage_dict.get("input_tokens", 0),
            completion_tokens=usage_dict.get("output_tokens", 0),
            total_tokens=usage_dict.get("input_tokens", 0) + usage_dict.get("output_tokens", 0),
            cache_creation_tokens=usage_dict.get("cache_creation_input_tokens", 0),
            cache_read_tokens=usage_dict.get("cache_read_input_tokens", 0),
        )

        # Calculate cost including cache costs
        base_cost = self._calculate_cost(usage, raw_response["model"])
        cache_cost = self._calculate_cache_cost(
            usage.cache_creation_tokens,
            usage.cache_read_tokens,
            raw_response["model"]
        )
        usage.cost_usd = base_cost + cache_cost

        # Add cost breakdown
        if usage.cache_creation_tokens > 0 or usage.cache_read_tokens > 0:
            usage.cost_breakdown = {
                "base_cost": base_cost,
                "cache_cost": cache_cost,
                "total_cost": usage.cost_usd,
            }

        return ChatResponse(
            id=raw_response["id"],
            model=raw_response["model"],
            content=content,
            finish_reason=raw_response.get("stop_reason", "stop"),
            usage=usage,
            provider=self.provider_name,
            created_at=datetime.now(),  # Anthropic doesn't provide timestamp
            raw_response=raw_response,
        )

    def _normalize_stream_chunk(self, chunk: str) -> ChatResponse:
        """Normalize streaming chunk to ChatResponse format."""
        return ChatResponse(
            id="",
            model="",
            content=chunk,
            finish_reason="",
            usage=Usage(
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0
            ),
            provider=self.provider_name,
            created_at=datetime.now(),
            raw_response={},
        )

    def _calculate_cost(self, usage: Usage, model: str) -> float:
        """
        Calculate cost in USD based on token usage (excluding cache costs).

        Args:
            usage: Token usage information
            model: Model name used

        Returns:
            Cost in USD
        """
        model_info = ANTHROPIC_MODELS.get(model, {})
        cost_input = model_info.get("cost_input", 0.0)
        cost_output = model_info.get("cost_output", 0.0)

        # Calculate non-cached prompt tokens
        non_cached_prompt_tokens = usage.prompt_tokens - usage.cache_read_tokens

        # Costs are per 1M tokens
        input_cost = (non_cached_prompt_tokens / 1_000_000) * cost_input
        output_cost = (usage.completion_tokens / 1_000_000) * cost_output

        return input_cost + output_cost

    def _calculate_cache_cost(
        self,
        cache_creation_tokens: int,
        cache_read_tokens: int,
        model: str
    ) -> float:
        """
        Calculate cost for cached tokens.

        Args:
            cache_creation_tokens: Number of tokens written to cache
            cache_read_tokens: Number of tokens read from cache
            model: Model name used

        Returns:
            Cost in USD for cache operations
        """
        model_info = ANTHROPIC_MODELS.get(model, {})

        # Check if model supports caching
        if not model_info.get("supports_caching", False):
            return 0.0

        cost_cache_write = model_info.get("cost_cache_write", 0.0)
        cost_cache_read = model_info.get("cost_cache_read", 0.0)

        # Costs are per 1M tokens
        write_cost = (cache_creation_tokens / 1_000_000) * cost_cache_write
        read_cost = (cache_read_tokens / 1_000_000) * cost_cache_read

        return write_cost + read_cost
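
For orientation, here is a minimal usage sketch of the provider above. It is not part of the package: the `Message` class name, the `ChatRequest` keyword arguments, and the model identifier are assumptions inferred from the fields this file references (`role`, `content`, `model`, `messages`, `max_tokens`); the real definitions live in stratifyai/models.py and stratifyai/config.py, which are not reproduced in this diff.

```python
# Hypothetical usage sketch -- not part of the package.
# `Message`, the ChatRequest keywords, and the model name are assumptions.
import asyncio

from stratifyai.models import ChatRequest, Message  # names assumed
from stratifyai.providers.anthropic import AnthropicProvider


async def main() -> None:
    provider = AnthropicProvider()  # falls back to the ANTHROPIC_API_KEY env var

    request = ChatRequest(
        model="claude-3-5-sonnet-20241022",  # assumed key of ANTHROPIC_MODELS
        messages=[
            Message(role="system", content="You are a terse assistant."),
            Message(role="user", content="Summarize prompt caching in one line."),
        ],
        max_tokens=256,
    )

    response = await provider.chat_completion(request)
    print(response.content)
    # Usage carries token counts plus the computed USD cost: per-1M-token pricing,
    # with cache writes/reads priced separately via _calculate_cache_cost.
    print(response.usage.total_tokens, response.usage.cost_usd)


asyncio.run(main())
```

Synchronous callers could instead use the inherited `chat_completion_sync`, which wraps the coroutine in `asyncio.run` (see the base class below).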

stratifyai/providers/base.py
@@ -0,0 +1,183 @@
"""Abstract base class for LLM providers."""

import asyncio
from abc import ABC, abstractmethod
from typing import AsyncIterator, List, Optional

from ..models import ChatRequest, ChatResponse, Usage
from ..exceptions import ValidationError


class BaseProvider(ABC):
    """Abstract base class that all LLM providers must implement."""

    def __init__(self, api_key: str, config: dict = None):
        """
        Initialize provider with API key and optional configuration.

        Args:
            api_key: Provider API key
            config: Optional provider-specific configuration
        """
        self.api_key = api_key
        self.config = config or {}
        self._client = None

    @abstractmethod
    def _initialize_client(self) -> None:
        """Initialize the provider-specific client library."""
        pass

    @abstractmethod
    async def chat_completion(self, request: ChatRequest) -> ChatResponse:
        """
        Execute a chat completion request.

        Args:
            request: Unified chat request

        Returns:
            Unified chat response

        Raises:
            InvalidModelError: If model not supported
            ProviderAPIError: If API call fails
        """
        pass

    @abstractmethod
    async def chat_completion_stream(
        self, request: ChatRequest
    ) -> AsyncIterator[ChatResponse]:
        """
        Execute a streaming chat completion request.

        Args:
            request: Unified chat request with stream=True

        Yields:
            Unified chat response chunks

        Raises:
            InvalidModelError: If model not supported
            ProviderAPIError: If API call fails
        """
        pass

    def chat_completion_sync(self, request: ChatRequest) -> ChatResponse:
        """
        Synchronous wrapper for chat_completion.

        Args:
            request: Unified chat request

        Returns:
            Unified chat response
        """
        return asyncio.run(self.chat_completion(request))

    @abstractmethod
    def _normalize_response(self, raw_response: dict) -> ChatResponse:
        """
        Convert provider-specific response to unified format.

        Args:
            raw_response: Raw response from provider API

        Returns:
            Normalized ChatResponse
        """
        pass

    @abstractmethod
    def _calculate_cost(self, usage: Usage, model: str) -> float:
        """
        Calculate cost for the request based on token usage.

        Args:
            usage: Token usage information
            model: Model name used

        Returns:
            Cost in USD
        """
        pass

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Return the provider name (e.g., 'openai', 'anthropic')."""
        pass

    @abstractmethod
    def get_supported_models(self) -> List[str]:
        """
        Return list of models supported by this provider.

        Returns:
            List of model names
        """
        pass

    def validate_model(self, model: str) -> bool:
        """
        Check if model is supported by this provider.

        Args:
            model: Model name to validate

        Returns:
            True if supported, False otherwise
        """
        return model in self.get_supported_models()

    def supports_caching(self, model: str) -> bool:
        """
        Check if model supports prompt caching.

        Args:
            model: Model name to check

        Returns:
            True if model supports prompt caching, False otherwise
        """
        # To be implemented by providers that support caching
        return False

    def _calculate_cache_cost(
        self,
        cache_creation_tokens: int,
        cache_read_tokens: int,
        model: str
    ) -> float:
        """
        Calculate cost for cached tokens.

        Args:
            cache_creation_tokens: Number of tokens written to cache
            cache_read_tokens: Number of tokens read from cache
            model: Model name used

        Returns:
            Cost in USD for cache operations
        """
        # Base implementation returns 0, override in providers that support caching
        return 0.0

    def validate_temperature(self, temperature: float, min_temp: float = 0.0, max_temp: float = 2.0) -> None:
        """
        Validate temperature parameter is within provider constraints.

        Args:
            temperature: Temperature value to validate
            min_temp: Minimum allowed temperature (provider-specific)
            max_temp: Maximum allowed temperature (provider-specific)

        Raises:
            ValidationError: If temperature is out of range
        """
        if not (min_temp <= temperature <= max_temp):
            raise ValidationError(
                f"{self.provider_name} temperature must be between {min_temp} and {max_temp}, "
                f"got {temperature}"
            )
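
To illustrate the contract this base class defines, below is a minimal sketch of a hypothetical provider. Only the set of abstract methods and their signatures comes from the class above; everything provider-specific (the "echo" behavior, the model name, the zero cost) is invented for illustration, and the sketch assumes `ChatResponse` and `Usage` accept the same constructor fields used in anthropic.py.

```python
# Sketch of a hypothetical provider built on BaseProvider -- illustration only.
from datetime import datetime
from typing import AsyncIterator, List

from stratifyai.exceptions import InvalidModelError
from stratifyai.models import ChatRequest, ChatResponse, Usage
from stratifyai.providers.base import BaseProvider


class EchoProvider(BaseProvider):
    """Toy provider that echoes the last message back (no real API calls)."""

    def _initialize_client(self) -> None:
        self._client = None  # a real provider would construct its SDK client here

    @property
    def provider_name(self) -> str:
        return "echo"

    def get_supported_models(self) -> List[str]:
        return ["echo-1"]  # invented model name

    async def chat_completion(self, request: ChatRequest) -> ChatResponse:
        if not self.validate_model(request.model):
            raise InvalidModelError(request.model, self.provider_name)
        raw = {"content": request.messages[-1].content, "model": request.model}
        return self._normalize_response(raw)

    async def chat_completion_stream(
        self, request: ChatRequest
    ) -> AsyncIterator[ChatResponse]:
        # Degenerate stream: emit the whole response as a single chunk.
        yield await self.chat_completion(request)

    def _normalize_response(self, raw_response: dict) -> ChatResponse:
        usage = Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
        usage.cost_usd = self._calculate_cost(usage, raw_response["model"])
        return ChatResponse(
            id="echo-0",
            model=raw_response["model"],
            content=raw_response["content"],
            finish_reason="stop",
            usage=usage,
            provider=self.provider_name,
            created_at=datetime.now(),
            raw_response=raw_response,
        )

    def _calculate_cost(self, usage: Usage, model: str) -> float:
        return 0.0  # nothing to bill for an echo
```

Such a subclass would inherit `chat_completion_sync`, `validate_model`, `validate_temperature`, and the zero-cost `_calculate_cache_cost` default, overriding `supports_caching` and `_calculate_cache_cost` only if its backend actually offers prompt caching.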