stratifyai-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. cli/__init__.py +5 -0
  2. cli/stratifyai_cli.py +1753 -0
  3. stratifyai/__init__.py +113 -0
  4. stratifyai/api_key_helper.py +372 -0
  5. stratifyai/caching.py +279 -0
  6. stratifyai/chat/__init__.py +54 -0
  7. stratifyai/chat/builder.py +366 -0
  8. stratifyai/chat/stratifyai_anthropic.py +194 -0
  9. stratifyai/chat/stratifyai_bedrock.py +200 -0
  10. stratifyai/chat/stratifyai_deepseek.py +194 -0
  11. stratifyai/chat/stratifyai_google.py +194 -0
  12. stratifyai/chat/stratifyai_grok.py +194 -0
  13. stratifyai/chat/stratifyai_groq.py +195 -0
  14. stratifyai/chat/stratifyai_ollama.py +201 -0
  15. stratifyai/chat/stratifyai_openai.py +209 -0
  16. stratifyai/chat/stratifyai_openrouter.py +201 -0
  17. stratifyai/chunking.py +158 -0
  18. stratifyai/client.py +292 -0
  19. stratifyai/config.py +1273 -0
  20. stratifyai/cost_tracker.py +257 -0
  21. stratifyai/embeddings.py +245 -0
  22. stratifyai/exceptions.py +91 -0
  23. stratifyai/models.py +59 -0
  24. stratifyai/providers/__init__.py +5 -0
  25. stratifyai/providers/anthropic.py +330 -0
  26. stratifyai/providers/base.py +183 -0
  27. stratifyai/providers/bedrock.py +634 -0
  28. stratifyai/providers/deepseek.py +39 -0
  29. stratifyai/providers/google.py +39 -0
  30. stratifyai/providers/grok.py +39 -0
  31. stratifyai/providers/groq.py +39 -0
  32. stratifyai/providers/ollama.py +43 -0
  33. stratifyai/providers/openai.py +344 -0
  34. stratifyai/providers/openai_compatible.py +372 -0
  35. stratifyai/providers/openrouter.py +39 -0
  36. stratifyai/py.typed +2 -0
  37. stratifyai/rag.py +381 -0
  38. stratifyai/retry.py +185 -0
  39. stratifyai/router.py +643 -0
  40. stratifyai/summarization.py +179 -0
  41. stratifyai/utils/__init__.py +11 -0
  42. stratifyai/utils/bedrock_validator.py +136 -0
  43. stratifyai/utils/code_extractor.py +327 -0
  44. stratifyai/utils/csv_extractor.py +197 -0
  45. stratifyai/utils/file_analyzer.py +192 -0
  46. stratifyai/utils/json_extractor.py +219 -0
  47. stratifyai/utils/log_extractor.py +267 -0
  48. stratifyai/utils/model_selector.py +324 -0
  49. stratifyai/utils/provider_validator.py +442 -0
  50. stratifyai/utils/token_counter.py +186 -0
  51. stratifyai/vectordb.py +344 -0
  52. stratifyai-0.1.0.dist-info/METADATA +263 -0
  53. stratifyai-0.1.0.dist-info/RECORD +57 -0
  54. stratifyai-0.1.0.dist-info/WHEEL +5 -0
  55. stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
  56. stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
  57. stratifyai-0.1.0.dist-info/top_level.txt +2 -0
stratifyai/providers/anthropic.py
@@ -0,0 +1,330 @@
+ """Anthropic provider implementation."""
+
+ import os
+ from datetime import datetime
+ from typing import AsyncIterator, List, Optional
+
+ from anthropic import AsyncAnthropic
+
+ from ..config import ANTHROPIC_MODELS, PROVIDER_CONSTRAINTS
+ from ..exceptions import AuthenticationError, InvalidModelError, ProviderAPIError
+ from ..models import ChatRequest, ChatResponse, Usage
+ from .base import BaseProvider
+
+
+ class AnthropicProvider(BaseProvider):
+     """Anthropic provider implementation with Messages API."""
+
+     def __init__(
+         self,
+         api_key: Optional[str] = None,
+         config: dict = None
+     ):
+         """
+         Initialize Anthropic provider.
+
+         Args:
+             api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
+             config: Optional provider-specific configuration
+
+         Raises:
+             AuthenticationError: If API key not provided
+         """
+         api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
+         if not api_key:
+             raise AuthenticationError("anthropic")
+         super().__init__(api_key, config)
+         self._initialize_client()
+
+     def _initialize_client(self) -> None:
+         """Initialize Anthropic async client."""
+         try:
+             self._client = AsyncAnthropic(api_key=self.api_key)
+         except Exception as e:
+             raise ProviderAPIError(
+                 f"Failed to initialize Anthropic client: {str(e)}",
+                 "anthropic"
+             )
+
+     @property
+     def provider_name(self) -> str:
+         """Return provider name."""
+         return "anthropic"
+
+     def get_supported_models(self) -> List[str]:
+         """Return list of supported Anthropic models."""
+         return list(ANTHROPIC_MODELS.keys())
+
+     def supports_caching(self, model: str) -> bool:
+         """Check if model supports prompt caching."""
+         model_info = ANTHROPIC_MODELS.get(model, {})
+         return model_info.get("supports_caching", False)
+
+     async def chat_completion(self, request: ChatRequest) -> ChatResponse:
+         """
+         Execute chat completion request using Messages API.
+
+         Args:
+             request: Unified chat request
+
+         Returns:
+             Unified chat response with cost tracking
+
+         Raises:
+             InvalidModelError: If model not supported
+             ProviderAPIError: If API call fails
+         """
+         if not self.validate_model(request.model):
+             raise InvalidModelError(request.model, self.provider_name)
+
+         # Validate temperature constraints for Anthropic (0.0 to 1.0)
+         constraints = PROVIDER_CONSTRAINTS.get(self.provider_name, {})
+         self.validate_temperature(
+             request.temperature,
+             constraints.get("min_temperature", 0.0),
+             constraints.get("max_temperature", 1.0)
+         )
+
+         # Convert messages to Anthropic format
+         # Anthropic requires the system message separate from the messages array
+         system_message = None
+         messages = []
+
+         for msg in request.messages:
+             if msg.role == "system":
+                 system_message = msg.content
+             else:
+                 message_dict = {"role": msg.role, "content": msg.content}
+                 # Add cache_control if present and the model supports caching
+                 if msg.cache_control and self.supports_caching(request.model):
+                     message_dict["cache_control"] = msg.cache_control
+                 messages.append(message_dict)
+
+         # Build Anthropic-specific request parameters
+         anthropic_params = {
+             "model": request.model,
+             "messages": messages,
+             "max_tokens": request.max_tokens or 4096,  # Anthropic requires max_tokens
+         }
+
+         # Anthropic only allows one of temperature or top_p.
+         # Prefer temperature if it is not the default; otherwise use top_p if it is not the default.
+         # Default temperature is 0.7, default top_p is 1.0.
+         if request.temperature != 0.7:
+             # Temperature was explicitly set, use it
+             anthropic_params["temperature"] = request.temperature
+         elif request.top_p != 1.0:
+             # top_p was explicitly set (not default), use it
+             anthropic_params["top_p"] = request.top_p
+         else:
+             # Both are defaults, use temperature
+             anthropic_params["temperature"] = request.temperature
+
+         # Add system message if present
+         if system_message:
+             anthropic_params["system"] = system_message
+
+         # Add optional parameters
+         if request.stop:
+             anthropic_params["stop_sequences"] = request.stop
+
+         # Add any extra params
+         if request.extra_params:
+             anthropic_params.update(request.extra_params)
+
+         try:
+             # Make API request
+             raw_response = await self._client.messages.create(**anthropic_params)
+             # Normalize and return
+             return self._normalize_response(raw_response.model_dump())
+         except Exception as e:
+             raise ProviderAPIError(
+                 f"Chat completion failed: {str(e)}",
+                 self.provider_name
+             )
+
+     async def chat_completion_stream(
+         self, request: ChatRequest
+     ) -> AsyncIterator[ChatResponse]:
+         """
+         Execute streaming chat completion request.
+
+         Args:
+             request: Unified chat request
+
+         Yields:
+             Unified chat response chunks
+
+         Raises:
+             InvalidModelError: If model not supported
+             ProviderAPIError: If API call fails
+         """
+         if not self.validate_model(request.model):
+             raise InvalidModelError(request.model, self.provider_name)
+
+         # Validate temperature constraints for Anthropic (0.0 to 1.0)
+         constraints = PROVIDER_CONSTRAINTS.get(self.provider_name, {})
+         self.validate_temperature(
+             request.temperature,
+             constraints.get("min_temperature", 0.0),
+             constraints.get("max_temperature", 1.0)
+         )
+
+         # Convert messages to Anthropic format
+         system_message = None
+         messages = []
+
+         for msg in request.messages:
+             if msg.role == "system":
+                 system_message = msg.content
+             else:
+                 messages.append({"role": msg.role, "content": msg.content})
+
+         # Build request parameters
+         anthropic_params = {
+             "model": request.model,
+             "messages": messages,
+             "temperature": request.temperature,
+             "max_tokens": request.max_tokens or 4096,
+         }
+
+         if system_message:
+             anthropic_params["system"] = system_message
+
+         try:
+             async with self._client.messages.stream(**anthropic_params) as stream:
+                 async for chunk in stream.text_stream:
+                     yield self._normalize_stream_chunk(chunk)
+         except Exception as e:
+             raise ProviderAPIError(
+                 f"Streaming chat completion failed: {str(e)}",
+                 self.provider_name
+             )
+
+     def _normalize_response(self, raw_response: dict) -> ChatResponse:
+         """
+         Convert Anthropic response to unified format.
+
+         Args:
+             raw_response: Raw Anthropic API response
+
+         Returns:
+             Normalized ChatResponse with cost
+         """
+         # Extract content from response
+         content = ""
+         if raw_response.get("content"):
+             for block in raw_response["content"]:
+                 if block.get("type") == "text":
+                     content += block.get("text", "")
+
+         # Extract token usage
+         usage_dict = raw_response.get("usage", {})
+         usage = Usage(
+             prompt_tokens=usage_dict.get("input_tokens", 0),
+             completion_tokens=usage_dict.get("output_tokens", 0),
+             total_tokens=usage_dict.get("input_tokens", 0) + usage_dict.get("output_tokens", 0),
+             cache_creation_tokens=usage_dict.get("cache_creation_input_tokens", 0),
+             cache_read_tokens=usage_dict.get("cache_read_input_tokens", 0),
+         )
+
+         # Calculate cost including cache costs
+         base_cost = self._calculate_cost(usage, raw_response["model"])
+         cache_cost = self._calculate_cache_cost(
+             usage.cache_creation_tokens,
+             usage.cache_read_tokens,
+             raw_response["model"]
+         )
+         usage.cost_usd = base_cost + cache_cost
+
+         # Add cost breakdown
+         if usage.cache_creation_tokens > 0 or usage.cache_read_tokens > 0:
+             usage.cost_breakdown = {
+                 "base_cost": base_cost,
+                 "cache_cost": cache_cost,
+                 "total_cost": usage.cost_usd,
+             }
+
+         return ChatResponse(
+             id=raw_response["id"],
+             model=raw_response["model"],
+             content=content,
+             finish_reason=raw_response.get("stop_reason", "stop"),
+             usage=usage,
+             provider=self.provider_name,
+             created_at=datetime.now(),  # Anthropic doesn't provide a timestamp
+             raw_response=raw_response,
+         )
+
+     def _normalize_stream_chunk(self, chunk: str) -> ChatResponse:
+         """Normalize streaming chunk to ChatResponse format."""
+         return ChatResponse(
+             id="",
+             model="",
+             content=chunk,
+             finish_reason="",
+             usage=Usage(
+                 prompt_tokens=0,
+                 completion_tokens=0,
+                 total_tokens=0
+             ),
+             provider=self.provider_name,
+             created_at=datetime.now(),
+             raw_response={},
+         )
+
+     def _calculate_cost(self, usage: Usage, model: str) -> float:
+         """
+         Calculate cost in USD based on token usage (excluding cache costs).
+
+         Args:
+             usage: Token usage information
+             model: Model name used
+
+         Returns:
+             Cost in USD
+         """
+         model_info = ANTHROPIC_MODELS.get(model, {})
+         cost_input = model_info.get("cost_input", 0.0)
+         cost_output = model_info.get("cost_output", 0.0)
+
+         # Calculate non-cached prompt tokens
+         non_cached_prompt_tokens = usage.prompt_tokens - usage.cache_read_tokens
+
+         # Costs are per 1M tokens
+         input_cost = (non_cached_prompt_tokens / 1_000_000) * cost_input
+         output_cost = (usage.completion_tokens / 1_000_000) * cost_output
+
+         return input_cost + output_cost
+
+     def _calculate_cache_cost(
+         self,
+         cache_creation_tokens: int,
+         cache_read_tokens: int,
+         model: str
+     ) -> float:
+         """
+         Calculate cost for cached tokens.
+
+         Args:
+             cache_creation_tokens: Number of tokens written to cache
+             cache_read_tokens: Number of tokens read from cache
+             model: Model name used
+
+         Returns:
+             Cost in USD for cache operations
+         """
+         model_info = ANTHROPIC_MODELS.get(model, {})
+
+         # Check if model supports caching
+         if not model_info.get("supports_caching", False):
+             return 0.0
+
+         cost_cache_write = model_info.get("cost_cache_write", 0.0)
+         cost_cache_read = model_info.get("cost_cache_read", 0.0)
+
+         # Costs are per 1M tokens
+         write_cost = (cache_creation_tokens / 1_000_000) * cost_cache_write
+         read_cost = (cache_read_tokens / 1_000_000) * cost_cache_read
+
+         return write_cost + read_cost
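
The provider above is driven entirely through the package's unified request/response models. Below is a minimal usage sketch, not taken from the package itself: the message class name (`Message`), the model key (`claude-3-5-sonnet-20241022`), and the example per-token rate are assumptions, since models.py and config.py are not shown in this excerpt; only the fields the provider code actually reads (model, messages, temperature, max_tokens, role, content) are taken from the diff.

# Hypothetical usage sketch; names marked below are assumptions, not confirmed by this diff.
import asyncio

from stratifyai.models import ChatRequest, Message  # "Message" class name is assumed
from stratifyai.providers.anthropic import AnthropicProvider


async def main() -> None:
    provider = AnthropicProvider()  # reads ANTHROPIC_API_KEY from the environment

    request = ChatRequest(
        model="claude-3-5-sonnet-20241022",  # assumed key in ANTHROPIC_MODELS
        messages=[
            Message(role="system", content="You are a concise assistant."),
            Message(role="user", content="Summarize prompt caching in one sentence."),
        ],
        temperature=0.2,  # not the 0.7 default, so the provider sends temperature rather than top_p
        max_tokens=256,
    )

    response = await provider.chat_completion(request)
    # cost_usd is derived from per-1M-token rates in ANTHROPIC_MODELS, e.g. 1,000
    # non-cached input tokens at an assumed $3.00/1M rate would cost $0.003.
    print(response.content)
    print(response.usage.cost_usd)


asyncio.run(main())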
stratifyai/providers/base.py
@@ -0,0 +1,183 @@
+ """Abstract base class for LLM providers."""
+
+ import asyncio
+ from abc import ABC, abstractmethod
+ from typing import AsyncIterator, List, Optional
+
+ from ..models import ChatRequest, ChatResponse, Usage
+ from ..exceptions import ValidationError
+
+
+ class BaseProvider(ABC):
+     """Abstract base class that all LLM providers must implement."""
+
+     def __init__(self, api_key: str, config: dict = None):
+         """
+         Initialize provider with API key and optional configuration.
+
+         Args:
+             api_key: Provider API key
+             config: Optional provider-specific configuration
+         """
+         self.api_key = api_key
+         self.config = config or {}
+         self._client = None
+
+     @abstractmethod
+     def _initialize_client(self) -> None:
+         """Initialize the provider-specific client library."""
+         pass
+
+     @abstractmethod
+     async def chat_completion(self, request: ChatRequest) -> ChatResponse:
+         """
+         Execute a chat completion request.
+
+         Args:
+             request: Unified chat request
+
+         Returns:
+             Unified chat response
+
+         Raises:
+             InvalidModelError: If model not supported
+             ProviderAPIError: If API call fails
+         """
+         pass
+
+     @abstractmethod
+     async def chat_completion_stream(
+         self, request: ChatRequest
+     ) -> AsyncIterator[ChatResponse]:
+         """
+         Execute a streaming chat completion request.
+
+         Args:
+             request: Unified chat request with stream=True
+
+         Yields:
+             Unified chat response chunks
+
+         Raises:
+             InvalidModelError: If model not supported
+             ProviderAPIError: If API call fails
+         """
+         pass
+
+     def chat_completion_sync(self, request: ChatRequest) -> ChatResponse:
+         """
+         Synchronous wrapper for chat_completion.
+
+         Args:
+             request: Unified chat request
+
+         Returns:
+             Unified chat response
+         """
+         return asyncio.run(self.chat_completion(request))
+
+     @abstractmethod
+     def _normalize_response(self, raw_response: dict) -> ChatResponse:
+         """
+         Convert provider-specific response to unified format.
+
+         Args:
+             raw_response: Raw response from provider API
+
+         Returns:
+             Normalized ChatResponse
+         """
+         pass
+
+     @abstractmethod
+     def _calculate_cost(self, usage: Usage, model: str) -> float:
+         """
+         Calculate cost for the request based on token usage.
+
+         Args:
+             usage: Token usage information
+             model: Model name used
+
+         Returns:
+             Cost in USD
+         """
+         pass
+
+     @property
+     @abstractmethod
+     def provider_name(self) -> str:
+         """Return the provider name (e.g., 'openai', 'anthropic')."""
+         pass
+
+     @abstractmethod
+     def get_supported_models(self) -> List[str]:
+         """
+         Return list of models supported by this provider.
+
+         Returns:
+             List of model names
+         """
+         pass
+
+     def validate_model(self, model: str) -> bool:
+         """
+         Check if model is supported by this provider.
+
+         Args:
+             model: Model name to validate
+
+         Returns:
+             True if supported, False otherwise
+         """
+         return model in self.get_supported_models()
+
+     def supports_caching(self, model: str) -> bool:
+         """
+         Check if model supports prompt caching.
+
+         Args:
+             model: Model name to check
+
+         Returns:
+             True if model supports prompt caching, False otherwise
+         """
+         # To be implemented by providers that support caching
+         return False
+
+     def _calculate_cache_cost(
+         self,
+         cache_creation_tokens: int,
+         cache_read_tokens: int,
+         model: str
+     ) -> float:
+         """
+         Calculate cost for cached tokens.
+
+         Args:
+             cache_creation_tokens: Number of tokens written to cache
+             cache_read_tokens: Number of tokens read from cache
+             model: Model name used
+
+         Returns:
+             Cost in USD for cache operations
+         """
+         # Base implementation returns 0; override in providers that support caching
+         return 0.0
+
+     def validate_temperature(self, temperature: float, min_temp: float = 0.0, max_temp: float = 2.0) -> None:
+         """
+         Validate temperature parameter is within provider constraints.
+
+         Args:
+             temperature: Temperature value to validate
+             min_temp: Minimum allowed temperature (provider-specific)
+             max_temp: Maximum allowed temperature (provider-specific)
+
+         Raises:
+             ValidationError: If temperature is out of range
+         """
+         if not (min_temp <= temperature <= max_temp):
+             raise ValidationError(
+                 f"{self.provider_name} temperature must be between {min_temp} and {max_temp}, "
+                 f"got {temperature}"
+             )