stratifyai-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. cli/__init__.py +5 -0
  2. cli/stratifyai_cli.py +1753 -0
  3. stratifyai/__init__.py +113 -0
  4. stratifyai/api_key_helper.py +372 -0
  5. stratifyai/caching.py +279 -0
  6. stratifyai/chat/__init__.py +54 -0
  7. stratifyai/chat/builder.py +366 -0
  8. stratifyai/chat/stratifyai_anthropic.py +194 -0
  9. stratifyai/chat/stratifyai_bedrock.py +200 -0
  10. stratifyai/chat/stratifyai_deepseek.py +194 -0
  11. stratifyai/chat/stratifyai_google.py +194 -0
  12. stratifyai/chat/stratifyai_grok.py +194 -0
  13. stratifyai/chat/stratifyai_groq.py +195 -0
  14. stratifyai/chat/stratifyai_ollama.py +201 -0
  15. stratifyai/chat/stratifyai_openai.py +209 -0
  16. stratifyai/chat/stratifyai_openrouter.py +201 -0
  17. stratifyai/chunking.py +158 -0
  18. stratifyai/client.py +292 -0
  19. stratifyai/config.py +1273 -0
  20. stratifyai/cost_tracker.py +257 -0
  21. stratifyai/embeddings.py +245 -0
  22. stratifyai/exceptions.py +91 -0
  23. stratifyai/models.py +59 -0
  24. stratifyai/providers/__init__.py +5 -0
  25. stratifyai/providers/anthropic.py +330 -0
  26. stratifyai/providers/base.py +183 -0
  27. stratifyai/providers/bedrock.py +634 -0
  28. stratifyai/providers/deepseek.py +39 -0
  29. stratifyai/providers/google.py +39 -0
  30. stratifyai/providers/grok.py +39 -0
  31. stratifyai/providers/groq.py +39 -0
  32. stratifyai/providers/ollama.py +43 -0
  33. stratifyai/providers/openai.py +344 -0
  34. stratifyai/providers/openai_compatible.py +372 -0
  35. stratifyai/providers/openrouter.py +39 -0
  36. stratifyai/py.typed +2 -0
  37. stratifyai/rag.py +381 -0
  38. stratifyai/retry.py +185 -0
  39. stratifyai/router.py +643 -0
  40. stratifyai/summarization.py +179 -0
  41. stratifyai/utils/__init__.py +11 -0
  42. stratifyai/utils/bedrock_validator.py +136 -0
  43. stratifyai/utils/code_extractor.py +327 -0
  44. stratifyai/utils/csv_extractor.py +197 -0
  45. stratifyai/utils/file_analyzer.py +192 -0
  46. stratifyai/utils/json_extractor.py +219 -0
  47. stratifyai/utils/log_extractor.py +267 -0
  48. stratifyai/utils/model_selector.py +324 -0
  49. stratifyai/utils/provider_validator.py +442 -0
  50. stratifyai/utils/token_counter.py +186 -0
  51. stratifyai/vectordb.py +344 -0
  52. stratifyai-0.1.0.dist-info/METADATA +263 -0
  53. stratifyai-0.1.0.dist-info/RECORD +57 -0
  54. stratifyai-0.1.0.dist-info/WHEEL +5 -0
  55. stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
  56. stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
  57. stratifyai-0.1.0.dist-info/top_level.txt +2 -0
stratifyai/providers/google.py
@@ -0,0 +1,39 @@
+ """Google Gemini provider implementation."""
+
+ import os
+ from typing import Optional
+
+ from ..config import GOOGLE_MODELS, PROVIDER_BASE_URLS
+ from ..exceptions import AuthenticationError
+ from .openai_compatible import OpenAICompatibleProvider
+
+
+ class GoogleProvider(OpenAICompatibleProvider):
+     """Google Gemini provider using OpenAI-compatible API."""
+
+     def __init__(
+         self,
+         api_key: Optional[str] = None,
+         config: dict = None
+     ):
+         """
+         Initialize Google Gemini provider.
+
+         Args:
+             api_key: Google API key (defaults to GOOGLE_API_KEY env var)
+             config: Optional provider-specific configuration
+
+         Raises:
+             AuthenticationError: If API key not provided
+         """
+         api_key = api_key or os.getenv("GOOGLE_API_KEY")
+         if not api_key:
+             raise AuthenticationError("google")
+
+         base_url = PROVIDER_BASE_URLS["google"]
+         super().__init__(api_key, base_url, GOOGLE_MODELS, config)
+
+     @property
+     def provider_name(self) -> str:
+         """Return provider name."""
+         return "google"
stratifyai/providers/grok.py
@@ -0,0 +1,39 @@
+ """Grok (X.AI) provider implementation."""
+
+ import os
+ from typing import Optional
+
+ from ..config import GROK_MODELS, PROVIDER_BASE_URLS
+ from ..exceptions import AuthenticationError
+ from .openai_compatible import OpenAICompatibleProvider
+
+
+ class GrokProvider(OpenAICompatibleProvider):
+     """Grok (X.AI) provider using OpenAI-compatible API."""
+
+     def __init__(
+         self,
+         api_key: Optional[str] = None,
+         config: dict = None
+     ):
+         """
+         Initialize Grok provider.
+
+         Args:
+             api_key: Grok API key (defaults to GROK_API_KEY env var)
+             config: Optional provider-specific configuration
+
+         Raises:
+             AuthenticationError: If API key not provided
+         """
+         api_key = api_key or os.getenv("GROK_API_KEY")
+         if not api_key:
+             raise AuthenticationError("grok")
+
+         base_url = PROVIDER_BASE_URLS["grok"]
+         super().__init__(api_key, base_url, GROK_MODELS, config)
+
+     @property
+     def provider_name(self) -> str:
+         """Return provider name."""
+         return "grok"
stratifyai/providers/groq.py
@@ -0,0 +1,39 @@
+ """Groq provider implementation."""
+
+ import os
+ from typing import Optional
+
+ from ..config import GROQ_MODELS, PROVIDER_BASE_URLS
+ from ..exceptions import AuthenticationError
+ from .openai_compatible import OpenAICompatibleProvider
+
+
+ class GroqProvider(OpenAICompatibleProvider):
+     """Groq provider using OpenAI-compatible API."""
+
+     def __init__(
+         self,
+         api_key: Optional[str] = None,
+         config: dict = None
+     ):
+         """
+         Initialize Groq provider.
+
+         Args:
+             api_key: Groq API key (defaults to GROQ_API_KEY env var)
+             config: Optional provider-specific configuration
+
+         Raises:
+             AuthenticationError: If API key not provided
+         """
+         api_key = api_key or os.getenv("GROQ_API_KEY")
+         if not api_key:
+             raise AuthenticationError("groq")
+
+         base_url = PROVIDER_BASE_URLS["groq"]
+         super().__init__(api_key, base_url, GROQ_MODELS, config)
+
+     @property
+     def provider_name(self) -> str:
+         """Return provider name."""
+         return "groq"
stratifyai/providers/ollama.py
@@ -0,0 +1,43 @@
+ """Ollama provider implementation for local models."""
+
+ import os
+ from typing import Optional
+
+ from ..config import OLLAMA_MODELS, PROVIDER_BASE_URLS
+ from .openai_compatible import OpenAICompatibleProvider
+
+
+ class OllamaProvider(OpenAICompatibleProvider):
+     """Ollama provider for local models using OpenAI-compatible API."""
+
+     def __init__(
+         self,
+         api_key: Optional[str] = None,
+         config: dict = None
+     ):
+         """
+         Initialize Ollama provider.
+
+         Args:
+             api_key: Optional API key (Ollama typically doesn't require one)
+             config: Optional provider-specific configuration (can include base_url)
+
+         Note:
+             Ollama runs locally and typically doesn't require an API key.
+             Default base URL is http://localhost:11434/v1
+         """
+         # Ollama doesn't require an API key, use placeholder
+         from ..api_key_helper import APIKeyHelper
+         api_key = APIKeyHelper.get_api_key("ollama", api_key) or "ollama"
+
+         # Allow custom base URL via config
+         base_url = PROVIDER_BASE_URLS["ollama"]
+         if config and "base_url" in config:
+             base_url = config["base_url"]
+
+         super().__init__(api_key, base_url, OLLAMA_MODELS, config)
+
+     @property
+     def provider_name(self) -> str:
+         """Return provider name."""
+         return "ollama"
stratifyai/providers/openai.py
@@ -0,0 +1,344 @@
+ """OpenAI provider implementation."""
+
+ import os
+ from datetime import datetime
+ from typing import AsyncIterator, List, Optional
+
+ from openai import AsyncOpenAI
+
+ from ..config import OPENAI_MODELS, PROVIDER_CONSTRAINTS
+ from ..exceptions import AuthenticationError, InvalidModelError, ProviderAPIError
+ from ..models import ChatRequest, ChatResponse, Usage
+ from .base import BaseProvider
+
+
+ class OpenAIProvider(BaseProvider):
+     """OpenAI provider implementation with cost tracking."""
+
+     def __init__(
+         self,
+         api_key: Optional[str] = None,
+         config: dict = None
+     ):
+         """
+         Initialize OpenAI provider.
+
+         Args:
+             api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
+             config: Optional provider-specific configuration
+
+         Raises:
+             ValueError: If API key not provided (with helpful setup instructions)
+         """
+         from ..api_key_helper import get_api_key_or_error
+         api_key = get_api_key_or_error("openai", api_key)
+         super().__init__(api_key, config)
+         self._initialize_client()
+
+     def _initialize_client(self) -> None:
+         """Initialize OpenAI async client."""
+         try:
+             self._client = AsyncOpenAI(api_key=self.api_key)
+         except Exception as e:
+             raise ProviderAPIError(
+                 f"Failed to initialize OpenAI client: {str(e)}",
+                 "openai"
+             )
+
+     @property
+     def provider_name(self) -> str:
+         """Return provider name."""
+         return "openai"
+
+     def get_supported_models(self) -> List[str]:
+         """Return list of supported OpenAI models."""
+         return list(OPENAI_MODELS.keys())
+
+     def supports_caching(self, model: str) -> bool:
+         """Check if model supports prompt caching."""
+         model_info = OPENAI_MODELS.get(model, {})
+         return model_info.get("supports_caching", False)
+
+     async def chat_completion(self, request: ChatRequest) -> ChatResponse:
+         """
+         Execute chat completion request.
+
+         Args:
+             request: Unified chat request
+
+         Returns:
+             Unified chat response with cost tracking
+
+         Raises:
+             InvalidModelError: If model not supported
+             ProviderAPIError: If API call fails
+         """
+         if not self.validate_model(request.model):
+             raise InvalidModelError(request.model, self.provider_name)
+
+         # Validate temperature constraints for OpenAI (0.0 to 2.0)
+         constraints = PROVIDER_CONSTRAINTS.get(self.provider_name, {})
+         self.validate_temperature(
+             request.temperature,
+             constraints.get("min_temperature", 0.0),
+             constraints.get("max_temperature", 2.0)
+         )
+
+         # Build OpenAI-specific request parameters
+         messages = []
+         for msg in request.messages:
+             message_dict = {"role": msg.role, "content": msg.content}
+             # Add cache_control if present and model supports caching
+             if msg.cache_control and self.supports_caching(request.model):
+                 message_dict["cache_control"] = msg.cache_control
+             messages.append(message_dict)
+
+         openai_params = {
+             "model": request.model,
+             "messages": messages,
+         }
+
+         # Check if model is a reasoning model (o-series)
+         model_info = OPENAI_MODELS.get(request.model, {})
+         is_reasoning_model = model_info.get("reasoning_model", False)
+
+         # Also check if model name starts with o1, o3, gpt-5, or just 'o' followed by a digit
+         # This catches variants like o1-preview, o1-2024-12-17, o3-mini, gpt-5, etc.
+         if not is_reasoning_model and request.model:
+             model_lower = request.model.lower()
+             # Match: o1*, o3*, gpt-5*, "reasoning", or o followed by digit
+             is_reasoning_model = (
+                 model_lower.startswith("o1") or
+                 model_lower.startswith("o3") or
+                 model_lower.startswith("gpt-5") or
+                 "reasoning" in model_lower or
+                 (model_lower.startswith("o") and len(model_lower) > 1 and model_lower[1].isdigit())
+             )
+
+         # Only add these parameters for non-reasoning models
+         # Reasoning models like o1, o1-mini, o3-mini don't support temperature/top_p/penalties
+         if not is_reasoning_model:
+             openai_params["temperature"] = request.temperature
+             openai_params["top_p"] = request.top_p
+             openai_params["frequency_penalty"] = request.frequency_penalty
+             openai_params["presence_penalty"] = request.presence_penalty
+
+         # Add optional parameters
+         if request.max_tokens:
+             openai_params["max_tokens"] = request.max_tokens
+         if request.stop:
+             openai_params["stop"] = request.stop
+
+         # Add reasoning_effort for o-series models
+         if request.reasoning_effort and "o" in request.model:
+             openai_params["reasoning_effort"] = request.reasoning_effort
+
+         # Add any extra params
+         if request.extra_params:
+             openai_params.update(request.extra_params)
+
+         try:
+             # Make API request
+             raw_response = await self._client.chat.completions.create(**openai_params)
+             # Normalize and return
+             return self._normalize_response(raw_response.model_dump())
+         except Exception as e:
+             raise ProviderAPIError(
+                 f"Chat completion failed: {str(e)}",
+                 self.provider_name
+             )
+
+     async def chat_completion_stream(
+         self, request: ChatRequest
+     ) -> AsyncIterator[ChatResponse]:
+         """
+         Execute streaming chat completion request.
+
+         Args:
+             request: Unified chat request
+
+         Yields:
+             Unified chat response chunks
+
+         Raises:
+             InvalidModelError: If model not supported
+             ProviderAPIError: If API call fails
+         """
+         if not self.validate_model(request.model):
+             raise InvalidModelError(request.model, self.provider_name)
+
+         # Build request parameters
+         openai_params = {
+             "model": request.model,
+             "messages": [
+                 {"role": msg.role, "content": msg.content}
+                 for msg in request.messages
+             ],
+             "stream": True,
+         }
+
+         # Check if model is a reasoning model
+         model_info = OPENAI_MODELS.get(request.model, {})
+         is_reasoning_model = model_info.get("reasoning_model", False)
+
+         # Also check if model name starts with o1, o3, gpt-5, or just 'o' followed by a digit
+         if not is_reasoning_model and request.model:
+             model_lower = request.model.lower()
+             is_reasoning_model = (
+                 model_lower.startswith("o1") or
+                 model_lower.startswith("o3") or
+                 model_lower.startswith("gpt-5") or
+                 "reasoning" in model_lower or
+                 (model_lower.startswith("o") and len(model_lower) > 1 and model_lower[1].isdigit())
+             )
+
+         # Only add temperature for non-reasoning models
+         if not is_reasoning_model:
+             openai_params["temperature"] = request.temperature
+
+         if request.max_tokens:
+             openai_params["max_tokens"] = request.max_tokens
+
+         try:
+             stream = await self._client.chat.completions.create(**openai_params)
+
+             async for chunk in stream:
+                 chunk_dict = chunk.model_dump()
+                 if chunk.choices and chunk.choices[0].delta.content:
+                     yield self._normalize_stream_chunk(chunk_dict)
+         except Exception as e:
+             raise ProviderAPIError(
+                 f"Streaming chat completion failed: {str(e)}",
+                 self.provider_name
+             )
+
+     def _normalize_response(self, raw_response: dict) -> ChatResponse:
+         """
+         Convert OpenAI response to unified format.
+
+         Args:
+             raw_response: Raw OpenAI API response
+
+         Returns:
+             Normalized ChatResponse with cost
+         """
+         choice = raw_response["choices"][0]
+         usage_dict = raw_response.get("usage", {})
+
+         # Extract token usage
+         prompt_details = usage_dict.get("prompt_tokens_details", {})
+         usage = Usage(
+             prompt_tokens=usage_dict.get("prompt_tokens", 0),
+             completion_tokens=usage_dict.get("completion_tokens", 0),
+             total_tokens=usage_dict.get("total_tokens", 0),
+             cached_tokens=prompt_details.get("cached_tokens", 0),
+             cache_creation_tokens=prompt_details.get("cache_creation_input_tokens", 0),
+             cache_read_tokens=prompt_details.get("cached_tokens", 0),
+             reasoning_tokens=usage_dict.get("completion_tokens_details", {}).get(
+                 "reasoning_tokens", 0
+             ),
+         )
+
+         # Calculate cost including cache costs
+         base_cost = self._calculate_cost(usage, raw_response["model"])
+         cache_cost = self._calculate_cache_cost(
+             usage.cache_creation_tokens,
+             usage.cache_read_tokens,
+             raw_response["model"]
+         )
+         usage.cost_usd = base_cost + cache_cost
+
+         # Add cost breakdown
+         if usage.cache_creation_tokens > 0 or usage.cache_read_tokens > 0:
+             usage.cost_breakdown = {
+                 "base_cost": base_cost,
+                 "cache_cost": cache_cost,
+                 "total_cost": usage.cost_usd,
+             }
+
+         return ChatResponse(
+             id=raw_response["id"],
+             model=raw_response["model"],
+             content=choice["message"]["content"] or "",
+             finish_reason=choice["finish_reason"],
+             usage=usage,
+             provider=self.provider_name,
+             created_at=datetime.fromtimestamp(raw_response["created"]),
+             raw_response=raw_response,
+         )
+
+     def _normalize_stream_chunk(self, chunk_dict: dict) -> ChatResponse:
+         """Normalize streaming chunk to ChatResponse format."""
+         choice = chunk_dict["choices"][0]
+         content = choice["delta"].get("content", "")
+
+         return ChatResponse(
+             id=chunk_dict["id"],
+             model=chunk_dict["model"],
+             content=content,
+             finish_reason=choice.get("finish_reason", ""),
+             usage=Usage(
+                 prompt_tokens=0,
+                 completion_tokens=0,
+                 total_tokens=0
+             ),
+             provider=self.provider_name,
+             created_at=datetime.fromtimestamp(chunk_dict["created"]),
+             raw_response=chunk_dict,
+         )
+
+     def _calculate_cost(self, usage: Usage, model: str) -> float:
+         """
+         Calculate cost in USD based on token usage (excluding cache costs).
+
+         Args:
+             usage: Token usage information
+             model: Model name used
+
+         Returns:
+             Cost in USD
+         """
+         model_info = OPENAI_MODELS.get(model, {})
+         cost_input = model_info.get("cost_input", 0.0)
+         cost_output = model_info.get("cost_output", 0.0)
+
+         # Calculate non-cached prompt tokens
+         non_cached_prompt_tokens = usage.prompt_tokens - usage.cache_read_tokens
+
+         # Costs are per 1M tokens
+         input_cost = (non_cached_prompt_tokens / 1_000_000) * cost_input
+         output_cost = (usage.completion_tokens / 1_000_000) * cost_output
+
+         return input_cost + output_cost
+
+     def _calculate_cache_cost(
+         self,
+         cache_creation_tokens: int,
+         cache_read_tokens: int,
+         model: str
+     ) -> float:
+         """
+         Calculate cost for cached tokens.
+
+         Args:
+             cache_creation_tokens: Number of tokens written to cache
+             cache_read_tokens: Number of tokens read from cache
+             model: Model name used
+
+         Returns:
+             Cost in USD for cache operations
+         """
+         model_info = OPENAI_MODELS.get(model, {})
+
+         # Check if model supports caching
+         if not model_info.get("supports_caching", False):
+             return 0.0
+
+         cost_cache_write = model_info.get("cost_cache_write", 0.0)
+         cost_cache_read = model_info.get("cost_cache_read", 0.0)
+
+         # Costs are per 1M tokens
+         write_cost = (cache_creation_tokens / 1_000_000) * cost_cache_write
+         read_cost = (cache_read_tokens / 1_000_000) * cost_cache_read
+
+         return write_cost + read_cost
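For reference, _calculate_cost above prices non-cached prompt tokens and completion tokens per million tokens, subtracting the cached portion of the prompt first. A worked sketch with hypothetical per-million rates (the real rates come from OPENAI_MODELS in stratifyai/config.py):

```python
# Hypothetical rates for illustration only; actual values live in OPENAI_MODELS.
cost_input = 2.50    # USD per 1M non-cached prompt tokens (assumed)
cost_output = 10.00  # USD per 1M completion tokens (assumed)

prompt_tokens = 12_000
cache_read_tokens = 4_000      # already-cached portion of the prompt
completion_tokens = 800

non_cached_prompt_tokens = prompt_tokens - cache_read_tokens        # 8_000
input_cost = (non_cached_prompt_tokens / 1_000_000) * cost_input    # 0.020
output_cost = (completion_tokens / 1_000_000) * cost_output         # 0.008

print(f"base cost: ${input_cost + output_cost:.4f}")  # base cost: $0.0280
```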