stratifyai-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +5 -0
- cli/stratifyai_cli.py +1753 -0
- stratifyai/__init__.py +113 -0
- stratifyai/api_key_helper.py +372 -0
- stratifyai/caching.py +279 -0
- stratifyai/chat/__init__.py +54 -0
- stratifyai/chat/builder.py +366 -0
- stratifyai/chat/stratifyai_anthropic.py +194 -0
- stratifyai/chat/stratifyai_bedrock.py +200 -0
- stratifyai/chat/stratifyai_deepseek.py +194 -0
- stratifyai/chat/stratifyai_google.py +194 -0
- stratifyai/chat/stratifyai_grok.py +194 -0
- stratifyai/chat/stratifyai_groq.py +195 -0
- stratifyai/chat/stratifyai_ollama.py +201 -0
- stratifyai/chat/stratifyai_openai.py +209 -0
- stratifyai/chat/stratifyai_openrouter.py +201 -0
- stratifyai/chunking.py +158 -0
- stratifyai/client.py +292 -0
- stratifyai/config.py +1273 -0
- stratifyai/cost_tracker.py +257 -0
- stratifyai/embeddings.py +245 -0
- stratifyai/exceptions.py +91 -0
- stratifyai/models.py +59 -0
- stratifyai/providers/__init__.py +5 -0
- stratifyai/providers/anthropic.py +330 -0
- stratifyai/providers/base.py +183 -0
- stratifyai/providers/bedrock.py +634 -0
- stratifyai/providers/deepseek.py +39 -0
- stratifyai/providers/google.py +39 -0
- stratifyai/providers/grok.py +39 -0
- stratifyai/providers/groq.py +39 -0
- stratifyai/providers/ollama.py +43 -0
- stratifyai/providers/openai.py +344 -0
- stratifyai/providers/openai_compatible.py +372 -0
- stratifyai/providers/openrouter.py +39 -0
- stratifyai/py.typed +2 -0
- stratifyai/rag.py +381 -0
- stratifyai/retry.py +185 -0
- stratifyai/router.py +643 -0
- stratifyai/summarization.py +179 -0
- stratifyai/utils/__init__.py +11 -0
- stratifyai/utils/bedrock_validator.py +136 -0
- stratifyai/utils/code_extractor.py +327 -0
- stratifyai/utils/csv_extractor.py +197 -0
- stratifyai/utils/file_analyzer.py +192 -0
- stratifyai/utils/json_extractor.py +219 -0
- stratifyai/utils/log_extractor.py +267 -0
- stratifyai/utils/model_selector.py +324 -0
- stratifyai/utils/provider_validator.py +442 -0
- stratifyai/utils/token_counter.py +186 -0
- stratifyai/vectordb.py +344 -0
- stratifyai-0.1.0.dist-info/METADATA +263 -0
- stratifyai-0.1.0.dist-info/RECORD +57 -0
- stratifyai-0.1.0.dist-info/WHEEL +5 -0
- stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
- stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
- stratifyai-0.1.0.dist-info/top_level.txt +2 -0

stratifyai/providers/bedrock.py
@@ -0,0 +1,634 @@
+"""AWS Bedrock provider implementation."""
+
+import json
+import os
+from datetime import datetime
+from typing import AsyncIterator, List, Optional
+
+try:
+    import aioboto3
+    from botocore.exceptions import ClientError, NoCredentialsError, BotoCoreError
+except ImportError:
+    raise ImportError(
+        "aioboto3 is required for AWS Bedrock async support. "
+        "Install with: pip install aioboto3>=12.0.0"
+    )
+
+from ..config import BEDROCK_MODELS, PROVIDER_CONSTRAINTS
+from ..exceptions import AuthenticationError, InvalidModelError, ProviderAPIError
+from ..models import ChatRequest, ChatResponse, Usage
+from .base import BaseProvider
+
+
+class BedrockProvider(BaseProvider):
+    """AWS Bedrock provider implementation using aioboto3 for async support."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,  # For compatibility with LLMClient (AWS_BEARER_TOKEN_BEDROCK)
+        aws_access_key_id: Optional[str] = None,
+        aws_secret_access_key: Optional[str] = None,
+        aws_session_token: Optional[str] = None,
+        region_name: Optional[str] = None,
+        config: dict = None
+    ):
+        """
+        Initialize AWS Bedrock provider.
+
+        Args:
+            api_key: AWS bearer token (defaults to AWS_BEARER_TOKEN_BEDROCK env var)
+                or for compatibility with LLMClient interface
+            aws_access_key_id: AWS access key (defaults to AWS_ACCESS_KEY_ID env var)
+            aws_secret_access_key: AWS secret key (defaults to AWS_SECRET_ACCESS_KEY env var)
+            aws_session_token: AWS session token (defaults to AWS_SESSION_TOKEN env var)
+            region_name: AWS region (defaults to AWS_DEFAULT_REGION or us-east-1)
+            config: Optional provider-specific configuration
+
+        Raises:
+            ValueError: If AWS credentials are not available (with helpful setup instructions)
+        """
+        # AWS Bedrock supports multiple authentication methods:
+        # 1. Bearer token (AWS_BEARER_TOKEN_BEDROCK)
+        # 2. Access key + secret key (AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY)
+        # 3. IAM roles (when running on AWS infrastructure)
+        # 4. ~/.aws/credentials file
+
+        # Check for bearer token first (simplest method)
+        bearer_token = api_key or os.getenv("AWS_BEARER_TOKEN_BEDROCK")
+
+        # Check for access key credentials
+        self.aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
+        self.aws_secret_access_key = aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")
+        self.aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")
+        self.region_name = region_name or os.getenv("AWS_DEFAULT_REGION", "us-east-1")
+
+        # Use APIKeyHelper for better error messages if no credentials found
+        if not bearer_token and not (self.aws_access_key_id and self.aws_secret_access_key):
+            from ..api_key_helper import get_api_key_or_error
+            try:
+                get_api_key_or_error("bedrock", bearer_token)
+            except ValueError:
+                # Allow to proceed if using IAM roles or ~/.aws/credentials
+                # boto3 will handle the credential chain
+                pass
+
+        # BaseProvider expects api_key, so we'll use access_key_id as a placeholder
+        # (Bedrock doesn't use API keys like other providers)
+        super().__init__(self.aws_access_key_id or "aws-credentials", config)
+        self._initialize_client()
+
+    def _initialize_client(self) -> None:
+        """Initialize AWS Bedrock session for async client creation."""
+        try:
+            # Create aioboto3 session with explicit credentials if provided
+            session_params = {"region_name": self.region_name}
+            if self.aws_access_key_id and self.aws_secret_access_key:
+                session_params["aws_access_key_id"] = self.aws_access_key_id
+                session_params["aws_secret_access_key"] = self.aws_secret_access_key
+                if self.aws_session_token:
+                    session_params["aws_session_token"] = self.aws_session_token
+
+            # Store session for async client creation
+            # aioboto3 clients must be created within async context
+            self._session = aioboto3.Session(**session_params)
+            self._client = None  # Will be created in async context
+
+        except NoCredentialsError:
+            raise AuthenticationError(
+                "AWS credentials not found. Set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY "
+                "environment variables or configure ~/.aws/credentials"
+            )
+        except Exception as e:
+            raise ProviderAPIError(
+                f"Failed to initialize AWS Bedrock session: {str(e)}",
+                "bedrock"
+            )
+
+    @property
+    def provider_name(self) -> str:
+        """Return provider name."""
+        return "bedrock"
+
+    def get_supported_models(self) -> List[str]:
+        """Return list of supported Bedrock models."""
+        return list(BEDROCK_MODELS.keys())
+
+    async def chat_completion(self, request: ChatRequest) -> ChatResponse:
+        """
+        Execute chat completion request using Bedrock.
+
+        Args:
+            request: Unified chat request
+
+        Returns:
+            Unified chat response with cost tracking
+
+        Raises:
+            InvalidModelError: If model not supported
+            ProviderAPIError: If API call fails
+        """
+        if not self.validate_model(request.model):
+            raise InvalidModelError(request.model, self.provider_name)
+
+        # Validate temperature constraints for Bedrock (0.0 to 1.0)
+        constraints = PROVIDER_CONSTRAINTS.get(self.provider_name, {})
+        self.validate_temperature(
+            request.temperature,
+            constraints.get("min_temperature", 0.0),
+            constraints.get("max_temperature", 1.0)
+        )
+
+        # Build request body based on model family
+        body = self._build_request_body(request)
+
+        try:
+            # Create async client and invoke Bedrock model
+            async with self._session.client("bedrock-runtime") as client:
+                response = await client.invoke_model(
+                    modelId=request.model,
+                    contentType="application/json",
+                    accept="application/json",
+                    body=json.dumps(body)
+                )
+
+                # Parse response - aioboto3 returns StreamingBody
+                response_body_bytes = await response["body"].read()
+                response_body = json.loads(response_body_bytes)
+
+                # Normalize response based on model family
+                return self._normalize_response(response_body, request.model)
+
+        except ClientError as e:
+            error_code = e.response["Error"]["Code"]
+            error_message = e.response["Error"]["Message"]
+            raise ProviderAPIError(
+                f"Bedrock API error ({error_code}): {error_message}",
+                self.provider_name
+            )
+        except Exception as e:
+            raise ProviderAPIError(
+                f"Chat completion failed: {str(e)}",
+                self.provider_name
+            )
+
+    async def chat_completion_stream(
+        self, request: ChatRequest
+    ) -> AsyncIterator[ChatResponse]:
+        """
+        Execute streaming chat completion request.
+
+        Args:
+            request: Unified chat request
+
+        Yields:
+            Unified chat response chunks
+
+        Raises:
+            InvalidModelError: If model not supported
+            ProviderAPIError: If API call fails
+        """
+        if not self.validate_model(request.model):
+            raise InvalidModelError(request.model, self.provider_name)
+
+        # Validate temperature constraints
+        constraints = PROVIDER_CONSTRAINTS.get(self.provider_name, {})
+        self.validate_temperature(
+            request.temperature,
+            constraints.get("min_temperature", 0.0),
+            constraints.get("max_temperature", 1.0)
+        )
+
+        # Build request body
+        body = self._build_request_body(request)
+
+        try:
+            # Create async client and invoke Bedrock model with streaming
+            async with self._session.client("bedrock-runtime") as client:
+                response = await client.invoke_model_with_response_stream(
+                    modelId=request.model,
+                    contentType="application/json",
+                    accept="application/json",
+                    body=json.dumps(body)
+                )
+
+                # Process streaming response
+                stream = response.get("body")
+                if stream:
+                    async for event in stream:
+                        chunk_data = event.get("chunk")
+                        if chunk_data:
+                            chunk = json.loads(chunk_data["bytes"].decode())
+                            yield self._normalize_stream_chunk(chunk, request.model)
+
+        except ClientError as e:
+            error_code = e.response["Error"]["Code"]
+            error_message = e.response["Error"]["Message"]
+            raise ProviderAPIError(
+                f"Bedrock streaming error ({error_code}): {error_message}",
+                self.provider_name
+            )
+        except Exception as e:
+            raise ProviderAPIError(
+                f"Streaming chat completion failed: {str(e)}",
+                self.provider_name
+            )
+
+    def _build_request_body(self, request: ChatRequest) -> dict:
+        """
+        Build request body based on model family.
+
+        Different Bedrock models have different request formats:
+        - Anthropic Claude: Uses Messages API format
+        - Meta Llama: Uses prompt-based format
+        - Mistral: Uses messages format
+        - Cohere: Uses prompt-based format
+        - Amazon Titan: Uses inputText format
+
+        Args:
+            request: Unified chat request
+
+        Returns:
+            Model-specific request body
+        """
+        model_id = request.model
+
+        # Anthropic Claude models
+        if model_id.startswith("anthropic.claude"):
+            return self._build_anthropic_request(request)
+
+        # Meta Llama models
+        elif model_id.startswith("meta.llama"):
+            return self._build_llama_request(request)
+
+        # Mistral models
+        elif model_id.startswith("mistral."):
+            return self._build_mistral_request(request)
+
+        # Cohere models
+        elif model_id.startswith("cohere."):
+            return self._build_cohere_request(request)
+
+        # Amazon Nova models (new generation)
+        elif model_id.startswith("amazon.nova"):
+            return self._build_nova_request(request)
+
+        # Amazon Titan models (legacy)
+        elif model_id.startswith("amazon.titan"):
+            return self._build_titan_request(request)
+
+        else:
+            raise InvalidModelError(
+                f"Unknown model family for {model_id}",
+                self.provider_name
+            )
+
+    def _build_anthropic_request(self, request: ChatRequest) -> dict:
+        """Build request for Anthropic Claude models."""
+        # Separate system message from conversation
+        system_message = None
+        messages = []
+
+        for msg in request.messages:
+            if msg.role == "system":
+                system_message = msg.content
+            else:
+                messages.append({"role": msg.role, "content": msg.content})
+
+        body = {
+            "anthropic_version": "bedrock-2023-05-31",
+            "messages": messages,
+            "max_tokens": request.max_tokens or 4096,
+            "temperature": request.temperature,
+        }
+
+        if system_message:
+            body["system"] = system_message
+
+        if request.top_p != 1.0:
+            body["top_p"] = request.top_p
+
+        if request.stop:
+            body["stop_sequences"] = request.stop
+
+        return body
+
+    def _build_llama_request(self, request: ChatRequest) -> dict:
+        """Build request for Meta Llama models."""
+        # Llama uses a prompt-based format
+        prompt = self._messages_to_prompt(request.messages)
+
+        return {
+            "prompt": prompt,
+            "max_gen_len": request.max_tokens or 2048,
+            "temperature": request.temperature,
+            "top_p": request.top_p,
+        }
+
+    def _build_mistral_request(self, request: ChatRequest) -> dict:
+        """Build request for Mistral models."""
+        # Convert to prompt format
+        prompt = self._messages_to_prompt(request.messages)
+
+        return {
+            "prompt": prompt,
+            "max_tokens": request.max_tokens or 2048,
+            "temperature": request.temperature,
+            "top_p": request.top_p,
+        }
+
+    def _build_cohere_request(self, request: ChatRequest) -> dict:
+        """Build request for Cohere models."""
+        # Cohere uses a message-based format similar to OpenAI
+        messages = []
+        for msg in request.messages:
+            messages.append({"role": msg.role, "message": msg.content})
+
+        return {
+            "message": messages[-1]["message"] if messages else "",
+            "chat_history": messages[:-1] if len(messages) > 1 else [],
+            "max_tokens": request.max_tokens or 2048,
+            "temperature": request.temperature,
+            "p": request.top_p,
+        }
+
+    def _build_nova_request(self, request: ChatRequest) -> dict:
+        """Build request for Amazon Nova models."""
+        # Nova uses messages API similar to Claude
+        system_message = None
+        messages = []
+
+        for msg in request.messages:
+            if msg.role == "system":
+                system_message = msg.content
+            else:
+                messages.append({"role": msg.role, "content": [{"text": msg.content}]})
+
+        body = {
+            "messages": messages,
+            "inferenceConfig": {
+                "max_new_tokens": request.max_tokens or 4096,
+                "temperature": request.temperature,
+                "top_p": request.top_p,
+            },
+            "schemaVersion": "messages-v1",
+        }
+
+        if system_message:
+            body["system"] = [{"text": system_message}]
+
+        if request.stop:
+            body["inferenceConfig"]["stopSequences"] = request.stop
+
+        return body
+
+    def _build_titan_request(self, request: ChatRequest) -> dict:
+        """Build request for Amazon Titan models."""
+        # Titan uses inputText format
+        prompt = self._messages_to_prompt(request.messages)
+
+        return {
+            "inputText": prompt,
+            "textGenerationConfig": {
+                "maxTokenCount": request.max_tokens or 2048,
+                "temperature": request.temperature,
+                "topP": request.top_p,
+                "stopSequences": request.stop or [],
+            }
+        }
+
+    def _messages_to_prompt(self, messages: List) -> str:
+        """
+        Convert message list to a single prompt string.
+
+        Args:
+            messages: List of Message objects
+
+        Returns:
+            Formatted prompt string
+        """
+        prompt_parts = []
+        for msg in messages:
+            if msg.role == "system":
+                prompt_parts.append(f"System: {msg.content}")
+            elif msg.role == "user":
+                prompt_parts.append(f"User: {msg.content}")
+            elif msg.role == "assistant":
+                prompt_parts.append(f"Assistant: {msg.content}")
+
+        return "\n\n".join(prompt_parts) + "\n\nAssistant:"
+
+    def _normalize_response(self, raw_response: dict, model: str) -> ChatResponse:
+        """
+        Convert Bedrock response to unified format.
+
+        Args:
+            raw_response: Raw Bedrock API response
+            model: Model ID used
+
+        Returns:
+            Normalized ChatResponse with cost
+        """
+        # Parse response based on model family
+        if model.startswith("anthropic.claude"):
+            content = self._parse_anthropic_response(raw_response)
+            usage = self._extract_anthropic_usage(raw_response)
+            finish_reason = raw_response.get("stop_reason", "stop")
+
+        elif model.startswith("meta.llama"):
+            content = raw_response.get("generation", "")
+            usage = self._extract_llama_usage(raw_response, content, model)
+            finish_reason = raw_response.get("stop_reason", "stop")
+
+        elif model.startswith("mistral."):
+            content = raw_response.get("outputs", [{}])[0].get("text", "")
+            usage = self._estimate_usage(content, model)
+            finish_reason = raw_response.get("stop_reason", "stop")
+
+        elif model.startswith("cohere."):
+            content = raw_response.get("text", "")
+            usage = self._extract_cohere_usage(raw_response, model)
+            finish_reason = raw_response.get("finish_reason", "COMPLETE")
+
+        elif model.startswith("amazon.nova"):
+            content = self._parse_nova_response(raw_response)
+            usage = self._extract_nova_usage(raw_response)
+            finish_reason = raw_response.get("stopReason", "end_turn")
+
+        elif model.startswith("amazon.titan"):
+            content = raw_response.get("results", [{}])[0].get("outputText", "")
+            usage = self._extract_titan_usage(raw_response, model)
+            finish_reason = raw_response.get("results", [{}])[0].get("completionReason", "FINISH")
+
+        else:
+            content = str(raw_response)
+            usage = Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
+            finish_reason = "stop"
+
+        # Calculate cost
+        cost = self._calculate_cost(usage, model)
+        usage.cost_usd = cost
+
+        return ChatResponse(
+            id=raw_response.get("id", f"bedrock-{datetime.now().timestamp()}"),
+            model=model,
+            content=content,
+            finish_reason=finish_reason,
+            usage=usage,
+            provider=self.provider_name,
+            created_at=datetime.now(),
+            raw_response=raw_response
+        )
+
+    def _parse_anthropic_response(self, response: dict) -> str:
+        """Extract content from Anthropic Claude response."""
+        content = ""
+        if response.get("content"):
+            for block in response["content"]:
+                if block.get("type") == "text":
+                    content += block.get("text", "")
+        return content
+
+    def _extract_anthropic_usage(self, response: dict) -> Usage:
+        """Extract usage from Anthropic Claude response."""
+        usage_data = response.get("usage", {})
+        return Usage(
+            prompt_tokens=usage_data.get("input_tokens", 0),
+            completion_tokens=usage_data.get("output_tokens", 0),
+            total_tokens=usage_data.get("input_tokens", 0) + usage_data.get("output_tokens", 0)
+        )
+
+    def _extract_llama_usage(self, response: dict, content: str, model: str) -> Usage:
+        """Extract or estimate usage for Llama models."""
+        # Llama doesn't always return token counts, so we estimate
+        prompt_tokens = response.get("prompt_token_count", 0)
+        completion_tokens = response.get("generation_token_count", 0)
+
+        # If not provided, estimate (rough: 1 token ≈ 4 characters)
+        if completion_tokens == 0 and content:
+            completion_tokens = len(content) // 4
+
+        return Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens
+        )
+
+    def _extract_cohere_usage(self, response: dict, model: str) -> Usage:
+        """Extract usage from Cohere response."""
+        # Cohere may not always provide token counts
+        prompt_tokens = response.get("prompt_tokens", 0)
+        completion_tokens = response.get("generation_tokens", 0)
+
+        return Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens
+        )
+
+    def _parse_nova_response(self, response: dict) -> str:
+        """Extract content from Amazon Nova response."""
+        content = ""
+        output = response.get("output", {})
+        if output.get("message"):
+            for block in output["message"].get("content", []):
+                if block.get("text"):
+                    content += block["text"]
+        return content
+
+    def _extract_nova_usage(self, response: dict) -> Usage:
+        """Extract usage from Amazon Nova response."""
+        usage_data = response.get("usage", {})
+        return Usage(
+            prompt_tokens=usage_data.get("inputTokens", 0),
+            completion_tokens=usage_data.get("outputTokens", 0),
+            total_tokens=usage_data.get("totalTokens", 0)
+        )
+
+    def _extract_titan_usage(self, response: dict, model: str) -> Usage:
+        """Extract usage from Titan response."""
+        result = response.get("results", [{}])[0]
+        prompt_tokens = result.get("inputTextTokenCount", 0)
+        completion_tokens = result.get("outputTextTokenCount", 0)
+
+        return Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens
+        )
+
+    def _estimate_usage(self, content: str, model: str) -> Usage:
+        """Estimate token usage when not provided by API."""
+        # Rough estimation: 1 token ≈ 4 characters
+        completion_tokens = len(content) // 4
+
+        return Usage(
+            prompt_tokens=0,  # Can't estimate prompt tokens without request
+            completion_tokens=completion_tokens,
+            total_tokens=completion_tokens
+        )
+
+    def _normalize_stream_chunk(self, chunk: dict, model: str) -> ChatResponse:
+        """
+        Convert streaming chunk to unified format.
+
+        Args:
+            chunk: Raw streaming chunk
+            model: Model ID used
+
+        Returns:
+            Normalized ChatResponse chunk
+        """
+        # Parse chunk based on model family
+        if model.startswith("anthropic.claude"):
+            if chunk.get("type") == "content_block_delta":
+                content = chunk.get("delta", {}).get("text", "")
+            else:
+                content = ""
+        elif model.startswith("meta.llama"):
+            content = chunk.get("generation", "")
+        elif model.startswith("mistral."):
+            content = chunk.get("outputs", [{}])[0].get("text", "")
+        elif model.startswith("amazon.nova"):
+            # Nova streaming format
+            if chunk.get("contentBlockDelta"):
+                content = chunk["contentBlockDelta"].get("delta", {}).get("text", "")
+            else:
+                content = ""
+        elif model.startswith("amazon.titan"):
+            content = chunk.get("outputText", "")
+        else:
+            content = ""
+
+        return ChatResponse(
+            id=f"bedrock-stream-{datetime.now().timestamp()}",
+            model=model,
+            content=content,
+            finish_reason="",
+            usage=Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+            provider=self.provider_name,
+            created_at=datetime.now(),
+            raw_response=chunk
+        )
+
+    def _calculate_cost(self, usage: Usage, model: str) -> float:
+        """
+        Calculate cost for Bedrock request.
+
+        Args:
+            usage: Token usage information
+            model: Model ID used
+
+        Returns:
+            Cost in USD
+        """
+        model_info = BEDROCK_MODELS.get(model, {})
+
+        # Get cost per million tokens
+        input_cost_per_mtok = model_info.get("cost_input", 0.0)
+        output_cost_per_mtok = model_info.get("cost_output", 0.0)
+
+        # Calculate cost
+        input_cost = (usage.prompt_tokens / 1_000_000) * input_cost_per_mtok
+        output_cost = (usage.completion_tokens / 1_000_000) * output_cost_per_mtok
+
+        return input_cost + output_cost
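
For orientation, here is a minimal usage sketch of the BedrockProvider added above; it is not part of the package diff. It assumes a Message type with role/content attributes exists in stratifyai.models (this hunk only imports ChatRequest, ChatResponse, and Usage), that ChatRequest accepts the fields the provider reads (model, messages, temperature, max_tokens) as keyword arguments, and that the model ID shown is only a placeholder for a real key of BEDROCK_MODELS.

    import asyncio
    from stratifyai.models import ChatRequest, Message  # Message is an assumption; not shown in this hunk
    from stratifyai.providers.bedrock import BedrockProvider

    async def main() -> None:
        # Credentials resolve from AWS_BEARER_TOKEN_BEDROCK, access keys, an IAM role,
        # or ~/.aws/credentials, exactly as the __init__ docstring above describes.
        provider = BedrockProvider(region_name="us-east-1")
        request = ChatRequest(
            model="anthropic.claude-3-haiku-20240307-v1:0",  # placeholder; valid IDs are the keys of BEDROCK_MODELS
            messages=[Message(role="user", content="Summarize AWS Bedrock in one sentence.")],
            temperature=0.5,
            max_tokens=256,
        )
        response = await provider.chat_completion(request)
        print(response.content)
        print(response.usage.total_tokens, response.usage.cost_usd)

    asyncio.run(main())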

stratifyai/providers/deepseek.py
@@ -0,0 +1,39 @@
+"""DeepSeek provider implementation."""
+
+import os
+from typing import Optional
+
+from ..config import DEEPSEEK_MODELS, PROVIDER_BASE_URLS
+from ..exceptions import AuthenticationError
+from .openai_compatible import OpenAICompatibleProvider
+
+
+class DeepSeekProvider(OpenAICompatibleProvider):
+    """DeepSeek provider using OpenAI-compatible API."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        config: dict = None
+    ):
+        """
+        Initialize DeepSeek provider.
+
+        Args:
+            api_key: DeepSeek API key (defaults to DEEPSEEK_API_KEY env var)
+            config: Optional provider-specific configuration
+
+        Raises:
+            AuthenticationError: If API key not provided
+        """
+        api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
+        if not api_key:
+            raise AuthenticationError("deepseek")
+
+        base_url = PROVIDER_BASE_URLS["deepseek"]
+        super().__init__(api_key, base_url, DEEPSEEK_MODELS, config)
+
+    @property
+    def provider_name(self) -> str:
+        """Return provider name."""
+        return "deepseek"
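
A similarly hedged sketch for DeepSeekProvider: only construction and provider_name are visible in this hunk, so everything else (chat calls, streaming, model listing) is assumed to come from OpenAICompatibleProvider, whose interface is not shown here.

    import os
    from stratifyai.providers.deepseek import DeepSeekProvider

    # __init__ raises AuthenticationError("deepseek") unless a key is passed in
    # or the DEEPSEEK_API_KEY environment variable is set.
    os.environ.setdefault("DEEPSEEK_API_KEY", "<your-deepseek-api-key>")

    provider = DeepSeekProvider()
    print(provider.provider_name)  # -> "deepseek"
    # Completions run through the inherited OpenAI-compatible client pointed at
    # PROVIDER_BASE_URLS["deepseek"]; that base class is not part of this hunk.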