stratifyai 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. cli/__init__.py +5 -0
  2. cli/stratifyai_cli.py +1753 -0
  3. stratifyai/__init__.py +113 -0
  4. stratifyai/api_key_helper.py +372 -0
  5. stratifyai/caching.py +279 -0
  6. stratifyai/chat/__init__.py +54 -0
  7. stratifyai/chat/builder.py +366 -0
  8. stratifyai/chat/stratifyai_anthropic.py +194 -0
  9. stratifyai/chat/stratifyai_bedrock.py +200 -0
  10. stratifyai/chat/stratifyai_deepseek.py +194 -0
  11. stratifyai/chat/stratifyai_google.py +194 -0
  12. stratifyai/chat/stratifyai_grok.py +194 -0
  13. stratifyai/chat/stratifyai_groq.py +195 -0
  14. stratifyai/chat/stratifyai_ollama.py +201 -0
  15. stratifyai/chat/stratifyai_openai.py +209 -0
  16. stratifyai/chat/stratifyai_openrouter.py +201 -0
  17. stratifyai/chunking.py +158 -0
  18. stratifyai/client.py +292 -0
  19. stratifyai/config.py +1273 -0
  20. stratifyai/cost_tracker.py +257 -0
  21. stratifyai/embeddings.py +245 -0
  22. stratifyai/exceptions.py +91 -0
  23. stratifyai/models.py +59 -0
  24. stratifyai/providers/__init__.py +5 -0
  25. stratifyai/providers/anthropic.py +330 -0
  26. stratifyai/providers/base.py +183 -0
  27. stratifyai/providers/bedrock.py +634 -0
  28. stratifyai/providers/deepseek.py +39 -0
  29. stratifyai/providers/google.py +39 -0
  30. stratifyai/providers/grok.py +39 -0
  31. stratifyai/providers/groq.py +39 -0
  32. stratifyai/providers/ollama.py +43 -0
  33. stratifyai/providers/openai.py +344 -0
  34. stratifyai/providers/openai_compatible.py +372 -0
  35. stratifyai/providers/openrouter.py +39 -0
  36. stratifyai/py.typed +2 -0
  37. stratifyai/rag.py +381 -0
  38. stratifyai/retry.py +185 -0
  39. stratifyai/router.py +643 -0
  40. stratifyai/summarization.py +179 -0
  41. stratifyai/utils/__init__.py +11 -0
  42. stratifyai/utils/bedrock_validator.py +136 -0
  43. stratifyai/utils/code_extractor.py +327 -0
  44. stratifyai/utils/csv_extractor.py +197 -0
  45. stratifyai/utils/file_analyzer.py +192 -0
  46. stratifyai/utils/json_extractor.py +219 -0
  47. stratifyai/utils/log_extractor.py +267 -0
  48. stratifyai/utils/model_selector.py +324 -0
  49. stratifyai/utils/provider_validator.py +442 -0
  50. stratifyai/utils/token_counter.py +186 -0
  51. stratifyai/vectordb.py +344 -0
  52. stratifyai-0.1.0.dist-info/METADATA +263 -0
  53. stratifyai-0.1.0.dist-info/RECORD +57 -0
  54. stratifyai-0.1.0.dist-info/WHEEL +5 -0
  55. stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
  56. stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
  57. stratifyai-0.1.0.dist-info/top_level.txt +2 -0
stratifyai/providers/openai_compatible.py ADDED
@@ -0,0 +1,372 @@
+ """Base class for OpenAI-compatible providers."""
+
+ from datetime import datetime
+ from typing import AsyncIterator, Dict, List, Optional
+
+ from openai import AsyncOpenAI, APIStatusError, APIError
+
+ from ..config import PROVIDER_CONSTRAINTS
+ from ..exceptions import ProviderAPIError, InvalidModelError, InsufficientBalanceError, AuthenticationError
+ from ..models import ChatRequest, ChatResponse, Usage
+ from .base import BaseProvider
+
+
+ class OpenAICompatibleProvider(BaseProvider):
+     """
+     Base class for providers with OpenAI-compatible APIs.
+
+     This includes: Google Gemini, DeepSeek, Groq, Grok, OpenRouter, Ollama.
+     """
+
+     def __init__(
+         self,
+         api_key: str,
+         base_url: str,
+         model_catalog: Dict,
+         config: Optional[dict] = None
+     ):
+         """
+         Initialize OpenAI-compatible provider.
+
+         Args:
+             api_key: Provider API key
+             base_url: Provider base URL
+             model_catalog: Model catalog for this provider
+             config: Optional provider-specific configuration
+         """
+         super().__init__(api_key, config)
+         self.base_url = base_url
+         self.model_catalog = model_catalog
+         self._initialize_client()
+
+     def _initialize_client(self) -> None:
+         """Initialize OpenAI-compatible async client."""
+         try:
+             self._client = AsyncOpenAI(
+                 api_key=self.api_key,
+                 base_url=self.base_url
+             )
+         except Exception as e:
+             raise ProviderAPIError(
+                 f"Failed to initialize {self.provider_name} client: {str(e)}",
+                 self.provider_name
+             )
+
+     def get_supported_models(self) -> List[str]:
+         """Return list of supported models."""
+         return list(self.model_catalog.keys())
+
+     def supports_caching(self, model: str) -> bool:
+         """Check if model supports prompt caching."""
+         model_info = self.model_catalog.get(model, {})
+         return model_info.get("supports_caching", False)
+
+     async def chat_completion(self, request: ChatRequest) -> ChatResponse:
+         """
+         Execute chat completion request.
+
+         Args:
+             request: Unified chat request
+
+         Returns:
+             Unified chat response with cost tracking
+
+         Raises:
+             InvalidModelError: If model not supported
+             ProviderAPIError: If API call fails
+         """
+         if not self.validate_model(request.model):
+             raise InvalidModelError(request.model, self.provider_name)
+
+         # Validate temperature constraints (most OpenAI-compatible providers use 0.0 to 2.0)
+         constraints = PROVIDER_CONSTRAINTS.get(self.provider_name, {})
+         self.validate_temperature(
+             request.temperature,
+             constraints.get("min_temperature", 0.0),
+             constraints.get("max_temperature", 2.0)
+         )
+
+         # Build OpenAI-compatible request parameters
+         messages = []
+         for msg in request.messages:
+             message_dict = {"role": msg.role, "content": msg.content}
+             # Add cache_control if present and model supports caching
+             if msg.cache_control and self.supports_caching(request.model):
+                 message_dict["cache_control"] = msg.cache_control
+             messages.append(message_dict)
+
+         openai_params = {
+             "model": request.model,
+             "messages": messages,
+         }
+
+         # Check if model is a reasoning model
+         model_info = self.model_catalog.get(request.model, {})
+         is_reasoning_model = model_info.get("reasoning_model", False)
+
+         # Also check model name patterns for reasoning models
+         if not is_reasoning_model and request.model:
+             model_lower = request.model.lower()
+             is_reasoning_model = (
+                 model_lower.startswith("o1") or
+                 model_lower.startswith("o3") or
+                 model_lower.startswith("gpt-5") or
+                 "reasoner" in model_lower or
+                 "reasoning" in model_lower or
+                 (model_lower.startswith("o") and len(model_lower) > 1 and model_lower[1].isdigit())
+             )
+
+         # Only add temperature and sampling params for non-reasoning models
+         if not is_reasoning_model:
+             openai_params["temperature"] = request.temperature
+             openai_params["top_p"] = request.top_p
+             if request.frequency_penalty:
+                 openai_params["frequency_penalty"] = request.frequency_penalty
+             if request.presence_penalty:
+                 openai_params["presence_penalty"] = request.presence_penalty
+
+         # Add optional parameters
+         if request.max_tokens:
+             openai_params["max_tokens"] = request.max_tokens
+         if request.stop:
+             openai_params["stop"] = request.stop
+
+         # Add any extra params
+         if request.extra_params:
+             openai_params.update(request.extra_params)
+
+         try:
+             # Make API request
+             raw_response = await self._client.chat.completions.create(**openai_params)
+             # Normalize and return
+             return self._normalize_response(raw_response.model_dump())
+         except (APIStatusError, APIError) as e:
+             error_msg = str(e)
+             # Check for specific error types
+             if "insufficient balance" in error_msg.lower():
+                 raise InsufficientBalanceError(self.provider_name)
+             elif "invalid_api_key" in error_msg.lower() or "unauthorized" in error_msg.lower() or (hasattr(e, 'status_code') and e.status_code == 401):
+                 raise AuthenticationError(self.provider_name)
+             else:
+                 raise ProviderAPIError(
+                     f"Chat completion failed: {error_msg}",
+                     self.provider_name
+                 )
+         except Exception as e:
+             raise ProviderAPIError(
+                 f"Chat completion failed: {str(e)}",
+                 self.provider_name
+             )
+
+     async def chat_completion_stream(
+         self, request: ChatRequest
+     ) -> AsyncIterator[ChatResponse]:
+         """
+         Execute streaming chat completion request.
+
+         Args:
+             request: Unified chat request
+
+         Yields:
+             Unified chat response chunks
+
+         Raises:
+             InvalidModelError: If model not supported
+             ProviderAPIError: If API call fails
+         """
+         if not self.validate_model(request.model):
+             raise InvalidModelError(request.model, self.provider_name)
+
+         # Validate temperature constraints (most OpenAI-compatible providers use 0.0 to 2.0)
+         constraints = PROVIDER_CONSTRAINTS.get(self.provider_name, {})
+         self.validate_temperature(
+             request.temperature,
+             constraints.get("min_temperature", 0.0),
+             constraints.get("max_temperature", 2.0)
+         )
+
+         # Build request parameters
+         openai_params = {
+             "model": request.model,
+             "messages": [
+                 {"role": msg.role, "content": msg.content}
+                 for msg in request.messages
+             ],
+             "stream": True,
+         }
+
+         # Check if model is a reasoning model
+         model_info = self.model_catalog.get(request.model, {})
+         is_reasoning_model = model_info.get("reasoning_model", False)
+
+         # Also check model name patterns for reasoning models
+         if not is_reasoning_model and request.model:
+             model_lower = request.model.lower()
+             is_reasoning_model = (
+                 model_lower.startswith("o1") or
+                 model_lower.startswith("o3") or
+                 model_lower.startswith("gpt-5") or
+                 "reasoner" in model_lower or
+                 "reasoning" in model_lower or
+                 (model_lower.startswith("o") and len(model_lower) > 1 and model_lower[1].isdigit())
+             )
+
+         # Only add temperature for non-reasoning models
+         if not is_reasoning_model:
+             openai_params["temperature"] = request.temperature
+
+         if request.max_tokens:
+             openai_params["max_tokens"] = request.max_tokens
+
+         try:
+             stream = await self._client.chat.completions.create(**openai_params)
+
+             async for chunk in stream:
+                 chunk_dict = chunk.model_dump()
+                 if chunk.choices and chunk.choices[0].delta.content:
+                     yield self._normalize_stream_chunk(chunk_dict)
+         except (APIStatusError, APIError) as e:
+             error_msg = str(e)
+             # Check for specific error types
+             if "insufficient balance" in error_msg.lower():
+                 raise InsufficientBalanceError(self.provider_name)
+             elif "invalid_api_key" in error_msg.lower() or "unauthorized" in error_msg.lower() or (hasattr(e, 'status_code') and e.status_code == 401):
+                 raise AuthenticationError(self.provider_name)
+             else:
+                 raise ProviderAPIError(
+                     f"Streaming chat completion failed: {error_msg}",
+                     self.provider_name
+                 )
+         except Exception as e:
+             raise ProviderAPIError(
+                 f"Streaming chat completion failed: {str(e)}",
+                 self.provider_name
+             )
+
+     def _normalize_response(self, raw_response: dict) -> ChatResponse:
+         """
+         Convert OpenAI-compatible response to unified format.
+
+         Args:
+             raw_response: Raw API response
+
+         Returns:
+             Normalized ChatResponse with cost
+         """
+         choice = raw_response["choices"][0]
+         usage_dict = raw_response.get("usage") or {}
+
+         # Extract token usage
+         prompt_details = usage_dict.get("prompt_tokens_details") or {}
+         usage = Usage(
+             prompt_tokens=usage_dict.get("prompt_tokens", 0),
+             completion_tokens=usage_dict.get("completion_tokens", 0),
+             total_tokens=usage_dict.get("total_tokens", 0),
+             cached_tokens=prompt_details.get("cached_tokens", 0),
+             cache_creation_tokens=prompt_details.get("cache_creation_input_tokens", 0),
+             cache_read_tokens=prompt_details.get("cached_tokens", 0),
+         )
+
+         # Calculate cost including cache costs
+         base_cost = self._calculate_cost(usage, raw_response["model"])
+         cache_cost = self._calculate_cache_cost(
+             usage.cache_creation_tokens,
+             usage.cache_read_tokens,
+             raw_response["model"]
+         )
+         usage.cost_usd = base_cost + cache_cost
+
+         # Add cost breakdown
+         if usage.cache_creation_tokens > 0 or usage.cache_read_tokens > 0:
+             usage.cost_breakdown = {
+                 "base_cost": base_cost,
+                 "cache_cost": cache_cost,
+                 "total_cost": usage.cost_usd,
+             }
+
+         return ChatResponse(
+             id=raw_response.get("id", ""),
+             model=raw_response["model"],
+             content=choice["message"]["content"] or "",
+             finish_reason=choice["finish_reason"],
+             usage=usage,
+             provider=self.provider_name,
+             created_at=datetime.fromtimestamp(raw_response.get("created", 0)) if raw_response.get("created") else datetime.now(),
+             raw_response=raw_response,
+         )
+
+     def _normalize_stream_chunk(self, chunk_dict: dict) -> ChatResponse:
+         """Normalize streaming chunk to ChatResponse format."""
+         choice = chunk_dict["choices"][0]
+         content = choice["delta"].get("content", "")
+
+         return ChatResponse(
+             id=chunk_dict.get("id", ""),
+             model=chunk_dict["model"],
+             content=content,
+             finish_reason=choice.get("finish_reason", ""),
+             usage=Usage(
+                 prompt_tokens=0,
+                 completion_tokens=0,
+                 total_tokens=0
+             ),
+             provider=self.provider_name,
+             created_at=datetime.fromtimestamp(chunk_dict.get("created", 0)) if chunk_dict.get("created") else datetime.now(),
+             raw_response=chunk_dict,
+         )
+
+     def _calculate_cost(self, usage: Usage, model: str) -> float:
+         """
+         Calculate cost in USD based on token usage (excluding cache costs).
+
+         Args:
+             usage: Token usage information
+             model: Model name used
+
+         Returns:
+             Cost in USD
+         """
+         model_info = self.model_catalog.get(model, {})
+         cost_input = model_info.get("cost_input", 0.0)
+         cost_output = model_info.get("cost_output", 0.0)
+
+         # Calculate non-cached prompt tokens
+         non_cached_prompt_tokens = usage.prompt_tokens - usage.cache_read_tokens
+
+         # Costs are per 1M tokens
+         input_cost = (non_cached_prompt_tokens / 1_000_000) * cost_input
+         output_cost = (usage.completion_tokens / 1_000_000) * cost_output
+
+         return input_cost + output_cost
+
+     def _calculate_cache_cost(
+         self,
+         cache_creation_tokens: int,
+         cache_read_tokens: int,
+         model: str
+     ) -> float:
+         """
+         Calculate cost for cached tokens.
+
+         Args:
+             cache_creation_tokens: Number of tokens written to cache
+             cache_read_tokens: Number of tokens read from cache
+             model: Model name used
+
+         Returns:
+             Cost in USD for cache operations
+         """
+         model_info = self.model_catalog.get(model, {})
+
+         # Check if model supports caching
+         if not model_info.get("supports_caching", False):
+             return 0.0
+
+         cost_cache_write = model_info.get("cost_cache_write", 0.0)
+         cost_cache_read = model_info.get("cost_cache_read", 0.0)
+
+         # Costs are per 1M tokens
+         write_cost = (cache_creation_tokens / 1_000_000) * cost_cache_write
+         read_cost = (cache_read_tokens / 1_000_000) * cost_cache_read
+
+         return write_cost + read_cost
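Note on the cost model above: _calculate_cost and _calculate_cache_cost both price tokens per million, charging non-cached prompt tokens at the catalog's cost_input rate, completion tokens at cost_output, and cache writes/reads at their own rates only when the catalog marks the model with supports_caching. A minimal standalone sketch of that arithmetic, using hypothetical rates (the real per-model rates live in the catalogs in stratifyai/config.py):

    # Hypothetical catalog entry; real rates come from stratifyai/config.py.
    model_info = {
        "cost_input": 0.50,         # USD per 1M non-cached prompt tokens
        "cost_output": 1.50,        # USD per 1M completion tokens
        "supports_caching": True,
        "cost_cache_write": 0.625,  # USD per 1M tokens written to cache
        "cost_cache_read": 0.05,    # USD per 1M tokens read from cache
    }

    prompt_tokens, cache_read_tokens, cache_write_tokens, completion_tokens = 10_000, 4_000, 0, 500

    base_cost = (
        ((prompt_tokens - cache_read_tokens) / 1_000_000) * model_info["cost_input"]
        + (completion_tokens / 1_000_000) * model_info["cost_output"]
    )  # 0.003 + 0.00075 = 0.00375
    cache_cost = (
        (cache_write_tokens / 1_000_000) * model_info["cost_cache_write"]
        + (cache_read_tokens / 1_000_000) * model_info["cost_cache_read"]
    )  # 0.0 + 0.0002 = 0.0002

    total_cost_usd = base_cost + cache_cost  # 0.00395 USD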
stratifyai/providers/openrouter.py ADDED
@@ -0,0 +1,39 @@
+ """OpenRouter provider implementation."""
2
+
3
+ import os
4
+ from typing import Optional
5
+
6
+ from ..config import OPENROUTER_MODELS, PROVIDER_BASE_URLS
7
+ from ..exceptions import AuthenticationError
8
+ from .openai_compatible import OpenAICompatibleProvider
9
+
10
+
11
+ class OpenRouterProvider(OpenAICompatibleProvider):
12
+ """OpenRouter provider using OpenAI-compatible API."""
13
+
14
+ def __init__(
15
+ self,
16
+ api_key: Optional[str] = None,
17
+         config: Optional[dict] = None
+     ):
+         """
+         Initialize OpenRouter provider.
+
+         Args:
+             api_key: OpenRouter API key (defaults to OPENROUTER_API_KEY env var)
+             config: Optional provider-specific configuration
+
+         Raises:
+             AuthenticationError: If API key not provided
+         """
+         api_key = api_key or os.getenv("OPENROUTER_API_KEY")
+         if not api_key:
+             raise AuthenticationError("openrouter")
+
+         base_url = PROVIDER_BASE_URLS["openrouter"]
+         super().__init__(api_key, base_url, OPENROUTER_MODELS, config)
+
+     @property
+     def provider_name(self) -> str:
+         """Return provider name."""
+         return "openrouter"
stratifyai/py.typed ADDED
@@ -0,0 +1,2 @@
+ # Marker file for PEP 561
+ # This package supports type hints
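Because the wheel ships this py.typed marker, PEP 561-aware checkers such as mypy and pyright consume the package's inline annotations instead of treating it as untyped. For instance, a downstream snippet like the following sketch is checked against the signatures in this diff:

    from stratifyai.providers.openrouter import OpenRouterProvider

    provider = OpenRouterProvider(api_key="...")         # api_key: Optional[str]
    models: list[str] = provider.get_supported_models()  # declared -> List[str]
    cached: bool = provider.supports_caching(models[0])  # declared -> bool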