stratifyai 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. cli/__init__.py +5 -0
  2. cli/stratifyai_cli.py +1753 -0
  3. stratifyai/__init__.py +113 -0
  4. stratifyai/api_key_helper.py +372 -0
  5. stratifyai/caching.py +279 -0
  6. stratifyai/chat/__init__.py +54 -0
  7. stratifyai/chat/builder.py +366 -0
  8. stratifyai/chat/stratifyai_anthropic.py +194 -0
  9. stratifyai/chat/stratifyai_bedrock.py +200 -0
  10. stratifyai/chat/stratifyai_deepseek.py +194 -0
  11. stratifyai/chat/stratifyai_google.py +194 -0
  12. stratifyai/chat/stratifyai_grok.py +194 -0
  13. stratifyai/chat/stratifyai_groq.py +195 -0
  14. stratifyai/chat/stratifyai_ollama.py +201 -0
  15. stratifyai/chat/stratifyai_openai.py +209 -0
  16. stratifyai/chat/stratifyai_openrouter.py +201 -0
  17. stratifyai/chunking.py +158 -0
  18. stratifyai/client.py +292 -0
  19. stratifyai/config.py +1273 -0
  20. stratifyai/cost_tracker.py +257 -0
  21. stratifyai/embeddings.py +245 -0
  22. stratifyai/exceptions.py +91 -0
  23. stratifyai/models.py +59 -0
  24. stratifyai/providers/__init__.py +5 -0
  25. stratifyai/providers/anthropic.py +330 -0
  26. stratifyai/providers/base.py +183 -0
  27. stratifyai/providers/bedrock.py +634 -0
  28. stratifyai/providers/deepseek.py +39 -0
  29. stratifyai/providers/google.py +39 -0
  30. stratifyai/providers/grok.py +39 -0
  31. stratifyai/providers/groq.py +39 -0
  32. stratifyai/providers/ollama.py +43 -0
  33. stratifyai/providers/openai.py +344 -0
  34. stratifyai/providers/openai_compatible.py +372 -0
  35. stratifyai/providers/openrouter.py +39 -0
  36. stratifyai/py.typed +2 -0
  37. stratifyai/rag.py +381 -0
  38. stratifyai/retry.py +185 -0
  39. stratifyai/router.py +643 -0
  40. stratifyai/summarization.py +179 -0
  41. stratifyai/utils/__init__.py +11 -0
  42. stratifyai/utils/bedrock_validator.py +136 -0
  43. stratifyai/utils/code_extractor.py +327 -0
  44. stratifyai/utils/csv_extractor.py +197 -0
  45. stratifyai/utils/file_analyzer.py +192 -0
  46. stratifyai/utils/json_extractor.py +219 -0
  47. stratifyai/utils/log_extractor.py +267 -0
  48. stratifyai/utils/model_selector.py +324 -0
  49. stratifyai/utils/provider_validator.py +442 -0
  50. stratifyai/utils/token_counter.py +186 -0
  51. stratifyai/vectordb.py +344 -0
  52. stratifyai-0.1.0.dist-info/METADATA +263 -0
  53. stratifyai-0.1.0.dist-info/RECORD +57 -0
  54. stratifyai-0.1.0.dist-info/WHEEL +5 -0
  55. stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
  56. stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
  57. stratifyai-0.1.0.dist-info/top_level.txt +2 -0
stratifyai/rag.py ADDED
@@ -0,0 +1,381 @@
"""RAG (Retrieval-Augmented Generation) pipeline implementation.

This module provides a complete RAG pipeline that integrates:
- Document chunking and indexing
- Embedding generation
- Vector database storage
- Semantic search
- LLM-based response generation with citations
"""

import asyncio
from dataclasses import dataclass
from typing import List, Optional, Dict, Any
from pathlib import Path

from .embeddings import EmbeddingProvider, create_embedding_provider
from .vectordb import VectorDBClient, SearchResult
from .client import LLMClient
from .models import ChatRequest, Message
from .chunking import chunk_content, get_chunk_metadata
from .exceptions import LLMAbstractionError


@dataclass
class IndexingResult:
    """Result of an indexing operation.

    Attributes:
        collection_name: Name of the collection
        num_chunks: Number of chunks indexed
        num_files: Number of files indexed
        total_tokens: Total tokens processed for embeddings
        embedding_cost: Cost of embedding generation
    """
    collection_name: str
    num_chunks: int
    num_files: int
    total_tokens: int
    embedding_cost: float


@dataclass
class RAGResponse:
    """Response from a RAG query.

    Attributes:
        content: Generated response text
        sources: List of source documents used
        model: LLM model used for generation
        total_cost: Generation cost in USD (query-embedding cost is not tracked)
        num_chunks_retrieved: Number of chunks retrieved
    """
    content: str
    sources: List[Dict[str, Any]]
    model: str
    total_cost: float
    num_chunks_retrieved: int


class RAGClient:
    """Client for RAG (Retrieval-Augmented Generation) operations.

    Provides a complete pipeline for:
    1. Indexing documents into a vector database
    2. Querying with semantic search
    3. Generating responses with an LLM using retrieved context
    """

    def __init__(
        self,
        embedding_provider: Optional[EmbeddingProvider] = None,
        llm_client: Optional[LLMClient] = None,
        persist_directory: Optional[str] = None
    ):
        """Initialize RAG client.

        Args:
            embedding_provider: Provider for embeddings (default: OpenAI)
            llm_client: LLM client for generation (default: creates new client)
            persist_directory: Directory for vector database (default: ./chroma_db)
        """
        # Initialize embedding provider
        if embedding_provider is None:
            embedding_provider = create_embedding_provider("openai")
        self.embedding_provider = embedding_provider

        # Initialize vector database
        self.vectordb = VectorDBClient(
            embedding_provider=embedding_provider,
            persist_directory=persist_directory
        )

        # Initialize LLM client
        if llm_client is None:
            llm_client = LLMClient()
        self.llm_client = llm_client

    async def index_file(
        self,
        file_path: str,
        collection_name: str,
        chunk_size: int = 1000,
        overlap: int = 200
    ) -> IndexingResult:
        """Index a single file into the vector database.

        Args:
            file_path: Path to file to index
            collection_name: Name of collection to store chunks
            chunk_size: Size of each chunk in characters
            overlap: Overlap between chunks in characters

        Returns:
            IndexingResult with statistics
        """
        # Read file
        path = Path(file_path)
        if not path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        content = path.read_text(encoding="utf-8")

        # Chunk content
        chunks = chunk_content(
            content,
            chunk_size=chunk_size,
            overlap=overlap,
            preserve_boundaries=True
        )

        # Create metadata for each chunk
        metadatas = [
            {
                "file": str(path.absolute()),
                "filename": path.name,
                "chunk_idx": i,
                "total_chunks": len(chunks)
            }
            for i in range(len(chunks))
        ]

        # Add to vector database (this generates embeddings automatically).
        # Run in a thread pool since ChromaDB is sync.
        await asyncio.to_thread(
            self.vectordb.add_documents,
            collection_name=collection_name,
            documents=chunks,
            metadatas=metadatas
        )

        # Get embedding stats (estimate)
        # Note: actual costs tracked by embedding provider during add_documents
        chunk_metadata = get_chunk_metadata(chunks)

        return IndexingResult(
            collection_name=collection_name,
            num_chunks=len(chunks),
            num_files=1,
            total_tokens=chunk_metadata["total_chars"] // 4,  # Rough estimate
            embedding_cost=0.0  # Would need to track from vectordb
        )

    async def index_directory(
        self,
        directory_path: str,
        collection_name: str,
        file_patterns: Optional[List[str]] = None,
        chunk_size: int = 1000,
        overlap: int = 200
    ) -> IndexingResult:
        """Index all files in a directory.

        Args:
            directory_path: Path to directory
            collection_name: Name of collection to store chunks
            file_patterns: List of file patterns (e.g., ["*.txt", "*.md"])
            chunk_size: Size of each chunk in characters
            overlap: Overlap between chunks in characters

        Returns:
            IndexingResult with statistics
        """
        dir_path = Path(directory_path)
        if not dir_path.exists() or not dir_path.is_dir():
            raise ValueError(f"Invalid directory: {directory_path}")

        # Default patterns
        if file_patterns is None:
            file_patterns = ["*.txt", "*.md", "*.py", "*.js", "*.java", "*.cpp"]

        # Find all matching files
        files = []
        for pattern in file_patterns:
            files.extend(dir_path.rglob(pattern))

        if not files:
            raise LLMAbstractionError(
                f"No files found matching patterns {file_patterns} in {directory_path}"
            )

        # Index each file, counting only the ones that succeed
        num_indexed = 0
        total_chunks = 0
        total_tokens = 0

        for file_path in files:
            try:
                result = await self.index_file(
                    file_path=str(file_path),
                    collection_name=collection_name,
                    chunk_size=chunk_size,
                    overlap=overlap
                )
                num_indexed += 1
                total_chunks += result.num_chunks
                total_tokens += result.total_tokens
            except Exception as e:
                # Log error but continue with other files
                print(f"Warning: Failed to index {file_path}: {str(e)}")

        return IndexingResult(
            collection_name=collection_name,
            num_chunks=total_chunks,
            num_files=num_indexed,
            total_tokens=total_tokens,
            embedding_cost=0.0
        )

    async def query(
        self,
        collection_name: str,
        query: str,
        provider: str = "openai",
        model: str = "gpt-4o-mini",
        n_results: int = 5,
        include_sources: bool = True
    ) -> RAGResponse:
        """Query the RAG system.

        Args:
            collection_name: Collection to query
            query: User query
            provider: LLM provider for generation
            model: LLM model for generation
            n_results: Number of chunks to retrieve
            include_sources: Whether to include source citations

        Returns:
            RAGResponse with generated content and sources
        """
        # Retrieve relevant chunks (run in a thread pool since ChromaDB is sync)
        search_results = await asyncio.to_thread(
            self.vectordb.query,
            collection_name=collection_name,
            query_text=query,
            n_results=n_results
        )

        if not search_results:
            raise LLMAbstractionError(
                f"No results found in collection '{collection_name}'. "
                "Ensure the collection exists and has documents."
            )

        # Build context from retrieved chunks
        context_parts = []
        sources = []

        for i, result in enumerate(search_results):
            context_parts.append(f"[Source {i+1}]\n{result.document}")
            sources.append({
                "index": i + 1,
                "file": result.metadata.get("filename", "unknown"),
                "chunk_idx": result.metadata.get("chunk_idx", 0),
                "similarity": 1.0 - result.distance  # Convert distance to similarity
            })

        context = "\n\n".join(context_parts)

        # Build prompt with context
        prompt = f"""Answer the following question using ONLY the information provided in the sources below.
If the answer cannot be found in the sources, say "I cannot answer this based on the provided sources."

Sources:
{context}

Question: {query}

Answer:"""

        # Generate response
        messages = [Message(role="user", content=prompt)]
        request = ChatRequest(
            model=model,
            messages=messages,
            temperature=0.3  # Lower temperature for more factual responses
        )

        # Create a client with the specified provider for this request
        client = LLMClient(provider=provider)
        response = await client.chat_completion(request=request)

        # Total cost currently reflects generation only; the cost of embedding
        # the query is not surfaced by the vector database layer.
        total_cost = response.usage.cost_usd if response.usage else 0.0

        return RAGResponse(
            content=response.content,
            sources=sources if include_sources else [],
            model=response.model,
            total_cost=total_cost,
            num_chunks_retrieved=len(search_results)
        )

    def list_collections(self) -> List[str]:
        """List all available collections.

        Returns:
            List of collection names
        """
        return self.vectordb.list_collections()

    def delete_collection(self, collection_name: str) -> None:
        """Delete a collection.

        Args:
            collection_name: Collection to delete
        """
        self.vectordb.delete_collection(collection_name)

    def get_collection_stats(self, collection_name: str) -> Dict[str, Any]:
        """Get statistics about a collection.

        Args:
            collection_name: Collection name

        Returns:
            Dictionary with collection statistics
        """
        count = self.vectordb.get_collection_count(collection_name)

        # Get sample documents to extract file info
        sample_docs = self.vectordb.get_documents(
            collection_name=collection_name,
            limit=100
        )

        # Extract unique files
        files = set()
        for doc in sample_docs:
            if "filename" in doc["metadata"]:
                files.add(doc["metadata"]["filename"])

        return {
            "name": collection_name,
            "num_chunks": count,
            "num_files": len(files),
            "sample_files": list(files)[:10]  # Show up to 10 files
        }

    def retrieve_only(
        self,
        collection_name: str,
        query: str,
        n_results: int = 5
    ) -> List[SearchResult]:
        """Retrieve relevant chunks without generating a response.

        Useful for testing retrieval quality.

        Args:
            collection_name: Collection to query
            query: Query text
            n_results: Number of results to return

        Returns:
            List of SearchResult objects
        """
        return self.vectordb.query(
            collection_name=collection_name,
            query_text=query,
            n_results=n_results
        )
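
A minimal end-to-end sketch of the pipeline above, not part of the published wheel: it assumes an OpenAI API key is configured in the environment, and the docs/ directory, the demo_docs collection name, and the question text are illustrative placeholders.

import asyncio

from stratifyai.rag import RAGClient


async def main() -> None:
    rag = RAGClient()  # defaults: OpenAI embeddings, new LLMClient, local ChromaDB

    # Index all Markdown files under docs/ into one collection.
    indexed = await rag.index_directory(
        directory_path="docs",
        collection_name="demo_docs",
        file_patterns=["*.md"],
    )
    print(f"Indexed {indexed.num_chunks} chunks from {indexed.num_files} files")

    # Check retrieval quality before paying for generation.
    for hit in rag.retrieve_only("demo_docs", "How does retry work?", n_results=3):
        print(hit.metadata.get("filename"), round(1.0 - hit.distance, 3))

    # Generate a grounded answer with source citations.
    answer = await rag.query(collection_name="demo_docs", query="How does retry work?")
    print(answer.content)
    for src in answer.sources:
        print(f"[{src['index']}] {src['file']} (chunk {src['chunk_idx']})")


asyncio.run(main())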
stratifyai/retry.py ADDED
@@ -0,0 +1,185 @@
"""Retry logic with exponential backoff and fallback support."""

import asyncio
import logging
import random
from typing import Callable, List, Optional, Type, Tuple
from functools import wraps
from dataclasses import dataclass

from .exceptions import (
    RateLimitError,
    ProviderAPIError,
    MaxRetriesExceededError,
)


@dataclass
class RetryConfig:
    """Configuration for retry behavior."""

    max_retries: int = 3
    initial_delay: float = 1.0  # seconds
    max_delay: float = 60.0  # seconds
    exponential_base: float = 2.0
    jitter: bool = True
    retry_on_exceptions: Tuple[Type[Exception], ...] = (
        RateLimitError,
        ProviderAPIError,
    )


def with_retry(
    config: Optional[RetryConfig] = None,
    fallback_models: Optional[List[str]] = None,
    fallback_provider: Optional[str] = None,
):
    """Decorator that adds async retry logic with exponential backoff.

    Args:
        config: Retry configuration
        fallback_models: List of fallback models to try if the primary fails
        fallback_provider: Fallback provider to use if the primary fails

    Usage:
        @with_retry(config=RetryConfig(max_retries=5))
        async def my_llm_call():
            ...
    """
    if config is None:
        config = RetryConfig()

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            last_exception = None

            for attempt in range(config.max_retries + 1):
                try:
                    return await func(*args, **kwargs)
                except config.retry_on_exceptions as e:
                    last_exception = e

                    if attempt == config.max_retries:
                        # Try fallbacks if configured
                        if fallback_models or fallback_provider:
                            logging.info(f"Attempting fallback after {attempt + 1} attempts")
                            return await _try_fallback_async(
                                func,
                                args,
                                kwargs,
                                fallback_models,
                                fallback_provider,
                                last_exception,
                            )
                        raise MaxRetriesExceededError(
                            f"Max retries ({config.max_retries}) exceeded. Last error: {str(e)}"
                        ) from e

                    # Calculate delay with exponential backoff
                    delay = min(
                        config.initial_delay * (config.exponential_base ** attempt),
                        config.max_delay
                    )

                    # Add jitter if enabled
                    if config.jitter:
                        delay *= (0.5 + random.random())

                    logging.warning(
                        f"Retry attempt {attempt + 1}/{config.max_retries} "
                        f"after {delay:.2f}s delay. Error: {str(e)}"
                    )
                    await asyncio.sleep(delay)

            # Should never reach here, but just in case
            raise last_exception

        return async_wrapper
    return decorator


async def _try_fallback_async(
    func: Callable,
    args: tuple,
    kwargs: dict,
    fallback_models: Optional[List[str]],
    fallback_provider: Optional[str],
    original_error: Exception,
):
    """Try fallback models or providers asynchronously.

    Args:
        func: Original async function
        args: Function args
        kwargs: Function kwargs
        fallback_models: List of fallback model names
        fallback_provider: Fallback provider name
        original_error: Original exception that triggered fallback

    Returns:
        Result from successful fallback

    Raises:
        MaxRetriesExceededError: If all fallbacks fail
    """
    # Try fallback models first
    if fallback_models:
        for model in fallback_models:
            try:
                logging.info(f"Trying fallback model: {model}")
                # Update the model in kwargs if present
                if 'request' in kwargs and hasattr(kwargs['request'], 'model'):
                    kwargs['request'].model = model
                elif 'model' in kwargs:
                    kwargs['model'] = model
                return await func(*args, **kwargs)
            except Exception as e:
                logging.warning(f"Fallback model {model} failed: {str(e)}")
                continue

    # Try fallback provider
    if fallback_provider:
        try:
            logging.info(f"Trying fallback provider: {fallback_provider}")
            if 'provider' in kwargs:
                kwargs['provider'] = fallback_provider
            return await func(*args, **kwargs)
        except Exception as e:
            logging.warning(f"Fallback provider {fallback_provider} failed: {str(e)}")

    # All fallbacks failed
    raise MaxRetriesExceededError(
        f"All retries and fallbacks failed. Original error: {str(original_error)}"
    ) from original_error


def exponential_backoff(
    attempt: int,
    initial_delay: float = 1.0,
    exponential_base: float = 2.0,
    max_delay: float = 60.0,
    jitter: bool = True,
) -> float:
    """Calculate exponential backoff delay.

    Args:
        attempt: Current attempt number (0-indexed)
        initial_delay: Initial delay in seconds
        exponential_base: Base for exponential growth
        max_delay: Maximum delay cap
        jitter: Add random jitter to delay

    Returns:
        Delay in seconds
    """
    delay = min(initial_delay * (exponential_base ** attempt), max_delay)

    if jitter:
        delay *= (0.5 + random.random())

    return delay
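
A short sketch of how with_retry composes with fallbacks, not part of the published wheel: flaky_call is a hypothetical coroutine standing in for a provider call, and the model names are placeholders. Note that model must be passed as a keyword argument for the fallback logic to rewrite it; when every attempt and fallback fails, MaxRetriesExceededError surfaces to the caller.

import asyncio

from stratifyai.retry import RetryConfig, exponential_backoff, with_retry
from stratifyai.exceptions import RateLimitError


@with_retry(
    config=RetryConfig(max_retries=2, initial_delay=0.1),
    fallback_models=["gpt-4o-mini"],
)
async def flaky_call(model: str = "gpt-4o") -> str:
    # Always rate limited: three attempts, then one try with the fallback
    # model, then MaxRetriesExceededError.
    raise RateLimitError(f"429 from {model}")


async def main() -> None:
    try:
        await flaky_call(model="gpt-4o")
    except Exception as e:
        print(f"Gave up: {e}")


asyncio.run(main())

# With the defaults (initial_delay=1.0, base=2.0, max_delay=60.0) and jitter
# disabled, attempts 0..4 back off for 1, 2, 4, 8, and 16 seconds:
print([exponential_backoff(n, jitter=False) for n in range(5)])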