stratifyai-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. cli/__init__.py +5 -0
  2. cli/stratifyai_cli.py +1753 -0
  3. stratifyai/__init__.py +113 -0
  4. stratifyai/api_key_helper.py +372 -0
  5. stratifyai/caching.py +279 -0
  6. stratifyai/chat/__init__.py +54 -0
  7. stratifyai/chat/builder.py +366 -0
  8. stratifyai/chat/stratifyai_anthropic.py +194 -0
  9. stratifyai/chat/stratifyai_bedrock.py +200 -0
  10. stratifyai/chat/stratifyai_deepseek.py +194 -0
  11. stratifyai/chat/stratifyai_google.py +194 -0
  12. stratifyai/chat/stratifyai_grok.py +194 -0
  13. stratifyai/chat/stratifyai_groq.py +195 -0
  14. stratifyai/chat/stratifyai_ollama.py +201 -0
  15. stratifyai/chat/stratifyai_openai.py +209 -0
  16. stratifyai/chat/stratifyai_openrouter.py +201 -0
  17. stratifyai/chunking.py +158 -0
  18. stratifyai/client.py +292 -0
  19. stratifyai/config.py +1273 -0
  20. stratifyai/cost_tracker.py +257 -0
  21. stratifyai/embeddings.py +245 -0
  22. stratifyai/exceptions.py +91 -0
  23. stratifyai/models.py +59 -0
  24. stratifyai/providers/__init__.py +5 -0
  25. stratifyai/providers/anthropic.py +330 -0
  26. stratifyai/providers/base.py +183 -0
  27. stratifyai/providers/bedrock.py +634 -0
  28. stratifyai/providers/deepseek.py +39 -0
  29. stratifyai/providers/google.py +39 -0
  30. stratifyai/providers/grok.py +39 -0
  31. stratifyai/providers/groq.py +39 -0
  32. stratifyai/providers/ollama.py +43 -0
  33. stratifyai/providers/openai.py +344 -0
  34. stratifyai/providers/openai_compatible.py +372 -0
  35. stratifyai/providers/openrouter.py +39 -0
  36. stratifyai/py.typed +2 -0
  37. stratifyai/rag.py +381 -0
  38. stratifyai/retry.py +185 -0
  39. stratifyai/router.py +643 -0
  40. stratifyai/summarization.py +179 -0
  41. stratifyai/utils/__init__.py +11 -0
  42. stratifyai/utils/bedrock_validator.py +136 -0
  43. stratifyai/utils/code_extractor.py +327 -0
  44. stratifyai/utils/csv_extractor.py +197 -0
  45. stratifyai/utils/file_analyzer.py +192 -0
  46. stratifyai/utils/json_extractor.py +219 -0
  47. stratifyai/utils/log_extractor.py +267 -0
  48. stratifyai/utils/model_selector.py +324 -0
  49. stratifyai/utils/provider_validator.py +442 -0
  50. stratifyai/utils/token_counter.py +186 -0
  51. stratifyai/vectordb.py +344 -0
  52. stratifyai-0.1.0.dist-info/METADATA +263 -0
  53. stratifyai-0.1.0.dist-info/RECORD +57 -0
  54. stratifyai-0.1.0.dist-info/WHEEL +5 -0
  55. stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
  56. stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
  57. stratifyai-0.1.0.dist-info/top_level.txt +2 -0
stratifyai/providers/bedrock.py
@@ -0,0 +1,634 @@
+"""AWS Bedrock provider implementation."""
+
+import json
+import os
+from datetime import datetime
+from typing import AsyncIterator, List, Optional
+
+try:
+    import aioboto3
+    from botocore.exceptions import ClientError, NoCredentialsError, BotoCoreError
+except ImportError:
+    raise ImportError(
+        "aioboto3 is required for AWS Bedrock async support. "
+        "Install with: pip install aioboto3>=12.0.0"
+    )
+
+from ..config import BEDROCK_MODELS, PROVIDER_CONSTRAINTS
+from ..exceptions import AuthenticationError, InvalidModelError, ProviderAPIError
+from ..models import ChatRequest, ChatResponse, Usage
+from .base import BaseProvider
+
+
+class BedrockProvider(BaseProvider):
+    """AWS Bedrock provider implementation using aioboto3 for async support."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,  # For compatibility with LLMClient (AWS_BEARER_TOKEN_BEDROCK)
+        aws_access_key_id: Optional[str] = None,
+        aws_secret_access_key: Optional[str] = None,
+        aws_session_token: Optional[str] = None,
+        region_name: Optional[str] = None,
+        config: dict = None
+    ):
+        """
+        Initialize AWS Bedrock provider.
+
+        Args:
+            api_key: AWS bearer token (defaults to AWS_BEARER_TOKEN_BEDROCK env var)
+                or for compatibility with LLMClient interface
+            aws_access_key_id: AWS access key (defaults to AWS_ACCESS_KEY_ID env var)
+            aws_secret_access_key: AWS secret key (defaults to AWS_SECRET_ACCESS_KEY env var)
+            aws_session_token: AWS session token (defaults to AWS_SESSION_TOKEN env var)
+            region_name: AWS region (defaults to AWS_DEFAULT_REGION or us-east-1)
+            config: Optional provider-specific configuration
+
+        Raises:
+            ValueError: If AWS credentials are not available (with helpful setup instructions)
+        """
+        # AWS Bedrock supports multiple authentication methods:
+        # 1. Bearer token (AWS_BEARER_TOKEN_BEDROCK)
+        # 2. Access key + secret key (AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY)
+        # 3. IAM roles (when running on AWS infrastructure)
+        # 4. ~/.aws/credentials file
+
+        # Check for bearer token first (simplest method)
+        bearer_token = api_key or os.getenv("AWS_BEARER_TOKEN_BEDROCK")
+
+        # Check for access key credentials
+        self.aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
+        self.aws_secret_access_key = aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")
+        self.aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")
+        self.region_name = region_name or os.getenv("AWS_DEFAULT_REGION", "us-east-1")
+
+        # Use APIKeyHelper for better error messages if no credentials found
+        if not bearer_token and not (self.aws_access_key_id and self.aws_secret_access_key):
+            from ..api_key_helper import get_api_key_or_error
+            try:
+                get_api_key_or_error("bedrock", bearer_token)
+            except ValueError:
+                # Allow to proceed if using IAM roles or ~/.aws/credentials
+                # boto3 will handle the credential chain
+                pass
+
+        # BaseProvider expects api_key, so we'll use access_key_id as a placeholder
+        # (Bedrock doesn't use API keys like other providers)
+        super().__init__(self.aws_access_key_id or "aws-credentials", config)
+        self._initialize_client()
+
+    def _initialize_client(self) -> None:
+        """Initialize AWS Bedrock session for async client creation."""
+        try:
+            # Create aioboto3 session with explicit credentials if provided
+            session_params = {"region_name": self.region_name}
+            if self.aws_access_key_id and self.aws_secret_access_key:
+                session_params["aws_access_key_id"] = self.aws_access_key_id
+                session_params["aws_secret_access_key"] = self.aws_secret_access_key
+                if self.aws_session_token:
+                    session_params["aws_session_token"] = self.aws_session_token
+
+            # Store session for async client creation
+            # aioboto3 clients must be created within async context
+            self._session = aioboto3.Session(**session_params)
+            self._client = None  # Will be created in async context
+
+        except NoCredentialsError:
+            raise AuthenticationError(
+                "AWS credentials not found. Set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY "
+                "environment variables or configure ~/.aws/credentials"
+            )
+        except Exception as e:
+            raise ProviderAPIError(
+                f"Failed to initialize AWS Bedrock session: {str(e)}",
+                "bedrock"
+            )
+
+    @property
+    def provider_name(self) -> str:
+        """Return provider name."""
+        return "bedrock"
+
+    def get_supported_models(self) -> List[str]:
+        """Return list of supported Bedrock models."""
+        return list(BEDROCK_MODELS.keys())
+
+    async def chat_completion(self, request: ChatRequest) -> ChatResponse:
+        """
+        Execute chat completion request using Bedrock.
+
+        Args:
+            request: Unified chat request
+
+        Returns:
+            Unified chat response with cost tracking
+
+        Raises:
+            InvalidModelError: If model not supported
+            ProviderAPIError: If API call fails
+        """
+        if not self.validate_model(request.model):
+            raise InvalidModelError(request.model, self.provider_name)
+
+        # Validate temperature constraints for Bedrock (0.0 to 1.0)
+        constraints = PROVIDER_CONSTRAINTS.get(self.provider_name, {})
+        self.validate_temperature(
+            request.temperature,
+            constraints.get("min_temperature", 0.0),
+            constraints.get("max_temperature", 1.0)
+        )
+
+        # Build request body based on model family
+        body = self._build_request_body(request)
+
+        try:
+            # Create async client and invoke Bedrock model
+            async with self._session.client("bedrock-runtime") as client:
+                response = await client.invoke_model(
+                    modelId=request.model,
+                    contentType="application/json",
+                    accept="application/json",
+                    body=json.dumps(body)
+                )
+
+                # Parse response - aioboto3 returns StreamingBody
+                response_body_bytes = await response["body"].read()
+                response_body = json.loads(response_body_bytes)
+
+                # Normalize response based on model family
+                return self._normalize_response(response_body, request.model)
+
+        except ClientError as e:
+            error_code = e.response["Error"]["Code"]
+            error_message = e.response["Error"]["Message"]
+            raise ProviderAPIError(
+                f"Bedrock API error ({error_code}): {error_message}",
+                self.provider_name
+            )
+        except Exception as e:
+            raise ProviderAPIError(
+                f"Chat completion failed: {str(e)}",
+                self.provider_name
+            )
+
+    async def chat_completion_stream(
+        self, request: ChatRequest
+    ) -> AsyncIterator[ChatResponse]:
+        """
+        Execute streaming chat completion request.
+
+        Args:
+            request: Unified chat request
+
+        Yields:
+            Unified chat response chunks
+
+        Raises:
+            InvalidModelError: If model not supported
+            ProviderAPIError: If API call fails
+        """
+        if not self.validate_model(request.model):
+            raise InvalidModelError(request.model, self.provider_name)
+
+        # Validate temperature constraints
+        constraints = PROVIDER_CONSTRAINTS.get(self.provider_name, {})
+        self.validate_temperature(
+            request.temperature,
+            constraints.get("min_temperature", 0.0),
+            constraints.get("max_temperature", 1.0)
+        )
+
+        # Build request body
+        body = self._build_request_body(request)
+
+        try:
+            # Create async client and invoke Bedrock model with streaming
+            async with self._session.client("bedrock-runtime") as client:
+                response = await client.invoke_model_with_response_stream(
+                    modelId=request.model,
+                    contentType="application/json",
+                    accept="application/json",
+                    body=json.dumps(body)
+                )
+
+                # Process streaming response
+                stream = response.get("body")
+                if stream:
+                    async for event in stream:
+                        chunk_data = event.get("chunk")
+                        if chunk_data:
+                            chunk = json.loads(chunk_data["bytes"].decode())
+                            yield self._normalize_stream_chunk(chunk, request.model)
+
+        except ClientError as e:
+            error_code = e.response["Error"]["Code"]
+            error_message = e.response["Error"]["Message"]
+            raise ProviderAPIError(
+                f"Bedrock streaming error ({error_code}): {error_message}",
+                self.provider_name
+            )
+        except Exception as e:
+            raise ProviderAPIError(
+                f"Streaming chat completion failed: {str(e)}",
+                self.provider_name
+            )
+
+    def _build_request_body(self, request: ChatRequest) -> dict:
+        """
+        Build request body based on model family.
+
+        Different Bedrock models have different request formats:
+        - Anthropic Claude: Uses Messages API format
+        - Meta Llama: Uses prompt-based format
+        - Mistral: Uses messages format
+        - Cohere: Uses prompt-based format
+        - Amazon Titan: Uses inputText format
+
+        Args:
+            request: Unified chat request
+
+        Returns:
+            Model-specific request body
+        """
+        model_id = request.model
+
+        # Anthropic Claude models
+        if model_id.startswith("anthropic.claude"):
+            return self._build_anthropic_request(request)
+
+        # Meta Llama models
+        elif model_id.startswith("meta.llama"):
+            return self._build_llama_request(request)
+
+        # Mistral models
+        elif model_id.startswith("mistral."):
+            return self._build_mistral_request(request)
+
+        # Cohere models
+        elif model_id.startswith("cohere."):
+            return self._build_cohere_request(request)
+
+        # Amazon Nova models (new generation)
+        elif model_id.startswith("amazon.nova"):
+            return self._build_nova_request(request)
+
+        # Amazon Titan models (legacy)
+        elif model_id.startswith("amazon.titan"):
+            return self._build_titan_request(request)
+
+        else:
+            raise InvalidModelError(
+                f"Unknown model family for {model_id}",
+                self.provider_name
+            )
+
+    def _build_anthropic_request(self, request: ChatRequest) -> dict:
+        """Build request for Anthropic Claude models."""
+        # Separate system message from conversation
+        system_message = None
+        messages = []
+
+        for msg in request.messages:
+            if msg.role == "system":
+                system_message = msg.content
+            else:
+                messages.append({"role": msg.role, "content": msg.content})
+
+        body = {
+            "anthropic_version": "bedrock-2023-05-31",
+            "messages": messages,
+            "max_tokens": request.max_tokens or 4096,
+            "temperature": request.temperature,
+        }
+
+        if system_message:
+            body["system"] = system_message
+
+        if request.top_p != 1.0:
+            body["top_p"] = request.top_p
+
+        if request.stop:
+            body["stop_sequences"] = request.stop
+
+        return body
+
+    def _build_llama_request(self, request: ChatRequest) -> dict:
+        """Build request for Meta Llama models."""
+        # Llama uses a prompt-based format
+        prompt = self._messages_to_prompt(request.messages)
+
+        return {
+            "prompt": prompt,
+            "max_gen_len": request.max_tokens or 2048,
+            "temperature": request.temperature,
+            "top_p": request.top_p,
+        }
+
+    def _build_mistral_request(self, request: ChatRequest) -> dict:
+        """Build request for Mistral models."""
+        # Convert to prompt format
+        prompt = self._messages_to_prompt(request.messages)
+
+        return {
+            "prompt": prompt,
+            "max_tokens": request.max_tokens or 2048,
+            "temperature": request.temperature,
+            "top_p": request.top_p,
+        }
+
+    def _build_cohere_request(self, request: ChatRequest) -> dict:
+        """Build request for Cohere models."""
+        # Cohere uses a message-based format similar to OpenAI
+        messages = []
+        for msg in request.messages:
+            messages.append({"role": msg.role, "message": msg.content})
+
+        return {
+            "message": messages[-1]["message"] if messages else "",
+            "chat_history": messages[:-1] if len(messages) > 1 else [],
+            "max_tokens": request.max_tokens or 2048,
+            "temperature": request.temperature,
+            "p": request.top_p,
+        }
+
+    def _build_nova_request(self, request: ChatRequest) -> dict:
+        """Build request for Amazon Nova models."""
+        # Nova uses messages API similar to Claude
+        system_message = None
+        messages = []
+
+        for msg in request.messages:
+            if msg.role == "system":
+                system_message = msg.content
+            else:
+                messages.append({"role": msg.role, "content": [{"text": msg.content}]})
+
+        body = {
+            "messages": messages,
+            "inferenceConfig": {
+                "max_new_tokens": request.max_tokens or 4096,
+                "temperature": request.temperature,
+                "top_p": request.top_p,
+            },
+            "schemaVersion": "messages-v1",
+        }
+
+        if system_message:
+            body["system"] = [{"text": system_message}]
+
+        if request.stop:
+            body["inferenceConfig"]["stopSequences"] = request.stop
+
+        return body
+
+    def _build_titan_request(self, request: ChatRequest) -> dict:
+        """Build request for Amazon Titan models."""
+        # Titan uses inputText format
+        prompt = self._messages_to_prompt(request.messages)
+
+        return {
+            "inputText": prompt,
+            "textGenerationConfig": {
+                "maxTokenCount": request.max_tokens or 2048,
+                "temperature": request.temperature,
+                "topP": request.top_p,
+                "stopSequences": request.stop or [],
+            }
+        }
+
+    def _messages_to_prompt(self, messages: List) -> str:
+        """
+        Convert message list to a single prompt string.
+
+        Args:
+            messages: List of Message objects
+
+        Returns:
+            Formatted prompt string
+        """
+        prompt_parts = []
+        for msg in messages:
+            if msg.role == "system":
+                prompt_parts.append(f"System: {msg.content}")
+            elif msg.role == "user":
+                prompt_parts.append(f"User: {msg.content}")
+            elif msg.role == "assistant":
+                prompt_parts.append(f"Assistant: {msg.content}")
+
+        return "\n\n".join(prompt_parts) + "\n\nAssistant:"
+
+    def _normalize_response(self, raw_response: dict, model: str) -> ChatResponse:
+        """
+        Convert Bedrock response to unified format.
+
+        Args:
+            raw_response: Raw Bedrock API response
+            model: Model ID used
+
+        Returns:
+            Normalized ChatResponse with cost
+        """
+        # Parse response based on model family
+        if model.startswith("anthropic.claude"):
+            content = self._parse_anthropic_response(raw_response)
+            usage = self._extract_anthropic_usage(raw_response)
+            finish_reason = raw_response.get("stop_reason", "stop")
+
+        elif model.startswith("meta.llama"):
+            content = raw_response.get("generation", "")
+            usage = self._extract_llama_usage(raw_response, content, model)
+            finish_reason = raw_response.get("stop_reason", "stop")
+
+        elif model.startswith("mistral."):
+            content = raw_response.get("outputs", [{}])[0].get("text", "")
+            usage = self._estimate_usage(content, model)
+            finish_reason = raw_response.get("stop_reason", "stop")
+
+        elif model.startswith("cohere."):
+            content = raw_response.get("text", "")
+            usage = self._extract_cohere_usage(raw_response, model)
+            finish_reason = raw_response.get("finish_reason", "COMPLETE")
+
+        elif model.startswith("amazon.nova"):
+            content = self._parse_nova_response(raw_response)
+            usage = self._extract_nova_usage(raw_response)
+            finish_reason = raw_response.get("stopReason", "end_turn")
+
+        elif model.startswith("amazon.titan"):
+            content = raw_response.get("results", [{}])[0].get("outputText", "")
+            usage = self._extract_titan_usage(raw_response, model)
+            finish_reason = raw_response.get("results", [{}])[0].get("completionReason", "FINISH")
+
+        else:
+            content = str(raw_response)
+            usage = Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
+            finish_reason = "stop"
+
+        # Calculate cost
+        cost = self._calculate_cost(usage, model)
+        usage.cost_usd = cost
+
+        return ChatResponse(
+            id=raw_response.get("id", f"bedrock-{datetime.now().timestamp()}"),
+            model=model,
+            content=content,
+            finish_reason=finish_reason,
+            usage=usage,
+            provider=self.provider_name,
+            created_at=datetime.now(),
+            raw_response=raw_response
+        )
+
+    def _parse_anthropic_response(self, response: dict) -> str:
+        """Extract content from Anthropic Claude response."""
+        content = ""
+        if response.get("content"):
+            for block in response["content"]:
+                if block.get("type") == "text":
+                    content += block.get("text", "")
+        return content
+
+    def _extract_anthropic_usage(self, response: dict) -> Usage:
+        """Extract usage from Anthropic Claude response."""
+        usage_data = response.get("usage", {})
+        return Usage(
+            prompt_tokens=usage_data.get("input_tokens", 0),
+            completion_tokens=usage_data.get("output_tokens", 0),
+            total_tokens=usage_data.get("input_tokens", 0) + usage_data.get("output_tokens", 0)
+        )
+
+    def _extract_llama_usage(self, response: dict, content: str, model: str) -> Usage:
+        """Extract or estimate usage for Llama models."""
+        # Llama doesn't always return token counts, so we estimate
+        prompt_tokens = response.get("prompt_token_count", 0)
+        completion_tokens = response.get("generation_token_count", 0)
+
+        # If not provided, estimate (rough: 1 token ≈ 4 characters)
+        if completion_tokens == 0 and content:
+            completion_tokens = len(content) // 4
+
+        return Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens
+        )
+
+    def _extract_cohere_usage(self, response: dict, model: str) -> Usage:
+        """Extract usage from Cohere response."""
+        # Cohere may not always provide token counts
+        prompt_tokens = response.get("prompt_tokens", 0)
+        completion_tokens = response.get("generation_tokens", 0)
+
+        return Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens
+        )
+
+    def _parse_nova_response(self, response: dict) -> str:
+        """Extract content from Amazon Nova response."""
+        content = ""
+        output = response.get("output", {})
+        if output.get("message"):
+            for block in output["message"].get("content", []):
+                if block.get("text"):
+                    content += block["text"]
+        return content
+
+    def _extract_nova_usage(self, response: dict) -> Usage:
+        """Extract usage from Amazon Nova response."""
+        usage_data = response.get("usage", {})
+        return Usage(
+            prompt_tokens=usage_data.get("inputTokens", 0),
+            completion_tokens=usage_data.get("outputTokens", 0),
+            total_tokens=usage_data.get("totalTokens", 0)
+        )
+
+    def _extract_titan_usage(self, response: dict, model: str) -> Usage:
+        """Extract usage from Titan response."""
+        result = response.get("results", [{}])[0]
+        prompt_tokens = result.get("inputTextTokenCount", 0)
+        completion_tokens = result.get("outputTextTokenCount", 0)
+
+        return Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens
+        )
+
+    def _estimate_usage(self, content: str, model: str) -> Usage:
+        """Estimate token usage when not provided by API."""
+        # Rough estimation: 1 token ≈ 4 characters
+        completion_tokens = len(content) // 4
+
+        return Usage(
+            prompt_tokens=0,  # Can't estimate prompt tokens without request
+            completion_tokens=completion_tokens,
+            total_tokens=completion_tokens
+        )
+
+    def _normalize_stream_chunk(self, chunk: dict, model: str) -> ChatResponse:
+        """
+        Convert streaming chunk to unified format.
+
+        Args:
+            chunk: Raw streaming chunk
+            model: Model ID used
+
+        Returns:
+            Normalized ChatResponse chunk
+        """
+        # Parse chunk based on model family
+        if model.startswith("anthropic.claude"):
+            if chunk.get("type") == "content_block_delta":
+                content = chunk.get("delta", {}).get("text", "")
+            else:
+                content = ""
+        elif model.startswith("meta.llama"):
+            content = chunk.get("generation", "")
+        elif model.startswith("mistral."):
+            content = chunk.get("outputs", [{}])[0].get("text", "")
+        elif model.startswith("amazon.nova"):
+            # Nova streaming format
+            if chunk.get("contentBlockDelta"):
+                content = chunk["contentBlockDelta"].get("delta", {}).get("text", "")
+            else:
+                content = ""
+        elif model.startswith("amazon.titan"):
+            content = chunk.get("outputText", "")
+        else:
+            content = ""
+
+        return ChatResponse(
+            id=f"bedrock-stream-{datetime.now().timestamp()}",
+            model=model,
+            content=content,
+            finish_reason="",
+            usage=Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+            provider=self.provider_name,
+            created_at=datetime.now(),
+            raw_response=chunk
+        )
+
+    def _calculate_cost(self, usage: Usage, model: str) -> float:
+        """
+        Calculate cost for Bedrock request.
+
+        Args:
+            usage: Token usage information
+            model: Model ID used
+
+        Returns:
+            Cost in USD
+        """
+        model_info = BEDROCK_MODELS.get(model, {})
+
+        # Get cost per million tokens
+        input_cost_per_mtok = model_info.get("cost_input", 0.0)
+        output_cost_per_mtok = model_info.get("cost_output", 0.0)
+
+        # Calculate cost
+        input_cost = (usage.prompt_tokens / 1_000_000) * input_cost_per_mtok
+        output_cost = (usage.completion_tokens / 1_000_000) * output_cost_per_mtok
+
+        return input_cost + output_cost
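
For orientation, here is a minimal usage sketch of the Bedrock provider above. It is not taken from the package's documentation: the ChatRequest/Message constructor fields and the example model ID are assumptions inferred from how this module reads them.

    import asyncio

    from stratifyai.models import ChatRequest, Message  # Message is assumed to be defined alongside ChatRequest
    from stratifyai.providers.bedrock import BedrockProvider

    async def main():
        # Credentials resolve via the env vars, IAM role, or ~/.aws/credentials chain described above.
        provider = BedrockProvider(region_name="us-east-1")
        request = ChatRequest(
            model="anthropic.claude-3-haiku-20240307-v1:0",  # hypothetical ID; must be a key in BEDROCK_MODELS
            messages=[Message(role="user", content="Summarize AWS Bedrock in one sentence.")],
            temperature=0.2,
        )
        response = await provider.chat_completion(request)
        print(response.content)
        print(response.usage.total_tokens, response.usage.cost_usd)

    asyncio.run(main())
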
stratifyai/providers/deepseek.py
@@ -0,0 +1,39 @@
+"""DeepSeek provider implementation."""
+
+import os
+from typing import Optional
+
+from ..config import DEEPSEEK_MODELS, PROVIDER_BASE_URLS
+from ..exceptions import AuthenticationError
+from .openai_compatible import OpenAICompatibleProvider
+
+
+class DeepSeekProvider(OpenAICompatibleProvider):
+    """DeepSeek provider using OpenAI-compatible API."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        config: dict = None
+    ):
+        """
+        Initialize DeepSeek provider.
+
+        Args:
+            api_key: DeepSeek API key (defaults to DEEPSEEK_API_KEY env var)
+            config: Optional provider-specific configuration
+
+        Raises:
+            AuthenticationError: If API key not provided
+        """
+        api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
+        if not api_key:
+            raise AuthenticationError("deepseek")
+
+        base_url = PROVIDER_BASE_URLS["deepseek"]
+        super().__init__(api_key, base_url, DEEPSEEK_MODELS, config)
+
+    @property
+    def provider_name(self) -> str:
+        """Return provider name."""
+        return "deepseek"