stratifyai-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +5 -0
- cli/stratifyai_cli.py +1753 -0
- stratifyai/__init__.py +113 -0
- stratifyai/api_key_helper.py +372 -0
- stratifyai/caching.py +279 -0
- stratifyai/chat/__init__.py +54 -0
- stratifyai/chat/builder.py +366 -0
- stratifyai/chat/stratifyai_anthropic.py +194 -0
- stratifyai/chat/stratifyai_bedrock.py +200 -0
- stratifyai/chat/stratifyai_deepseek.py +194 -0
- stratifyai/chat/stratifyai_google.py +194 -0
- stratifyai/chat/stratifyai_grok.py +194 -0
- stratifyai/chat/stratifyai_groq.py +195 -0
- stratifyai/chat/stratifyai_ollama.py +201 -0
- stratifyai/chat/stratifyai_openai.py +209 -0
- stratifyai/chat/stratifyai_openrouter.py +201 -0
- stratifyai/chunking.py +158 -0
- stratifyai/client.py +292 -0
- stratifyai/config.py +1273 -0
- stratifyai/cost_tracker.py +257 -0
- stratifyai/embeddings.py +245 -0
- stratifyai/exceptions.py +91 -0
- stratifyai/models.py +59 -0
- stratifyai/providers/__init__.py +5 -0
- stratifyai/providers/anthropic.py +330 -0
- stratifyai/providers/base.py +183 -0
- stratifyai/providers/bedrock.py +634 -0
- stratifyai/providers/deepseek.py +39 -0
- stratifyai/providers/google.py +39 -0
- stratifyai/providers/grok.py +39 -0
- stratifyai/providers/groq.py +39 -0
- stratifyai/providers/ollama.py +43 -0
- stratifyai/providers/openai.py +344 -0
- stratifyai/providers/openai_compatible.py +372 -0
- stratifyai/providers/openrouter.py +39 -0
- stratifyai/py.typed +2 -0
- stratifyai/rag.py +381 -0
- stratifyai/retry.py +185 -0
- stratifyai/router.py +643 -0
- stratifyai/summarization.py +179 -0
- stratifyai/utils/__init__.py +11 -0
- stratifyai/utils/bedrock_validator.py +136 -0
- stratifyai/utils/code_extractor.py +327 -0
- stratifyai/utils/csv_extractor.py +197 -0
- stratifyai/utils/file_analyzer.py +192 -0
- stratifyai/utils/json_extractor.py +219 -0
- stratifyai/utils/log_extractor.py +267 -0
- stratifyai/utils/model_selector.py +324 -0
- stratifyai/utils/provider_validator.py +442 -0
- stratifyai/utils/token_counter.py +186 -0
- stratifyai/vectordb.py +344 -0
- stratifyai-0.1.0.dist-info/METADATA +263 -0
- stratifyai-0.1.0.dist-info/RECORD +57 -0
- stratifyai-0.1.0.dist-info/WHEEL +5 -0
- stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
- stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
- stratifyai-0.1.0.dist-info/top_level.txt +2 -0
stratifyai/rag.py
ADDED
@@ -0,0 +1,381 @@

```python
"""RAG (Retrieval-Augmented Generation) pipeline implementation.

This module provides a complete RAG pipeline that integrates:
- Document chunking and indexing
- Embedding generation
- Vector database storage
- Semantic search
- LLM-based response generation with citations
"""

import asyncio
from dataclasses import dataclass
from typing import List, Optional, Dict, Any
from pathlib import Path
import os

from .embeddings import EmbeddingProvider, create_embedding_provider
from .vectordb import VectorDBClient, SearchResult
from .client import LLMClient
from .models import ChatRequest, Message
from .chunking import chunk_content, get_chunk_metadata
from .exceptions import LLMAbstractionError


@dataclass
class IndexingResult:
    """Result of indexing operation.

    Attributes:
        collection_name: Name of the collection
        num_chunks: Number of chunks indexed
        num_files: Number of files indexed
        total_tokens: Total tokens processed for embeddings
        embedding_cost: Cost of embedding generation
    """
    collection_name: str
    num_chunks: int
    num_files: int
    total_tokens: int
    embedding_cost: float


@dataclass
class RAGResponse:
    """Response from RAG query.

    Attributes:
        content: Generated response text
        sources: List of source documents used
        model: LLM model used for generation
        total_cost: Total cost (embeddings + generation)
        num_chunks_retrieved: Number of chunks retrieved
    """
    content: str
    sources: List[Dict[str, Any]]
    model: str
    total_cost: float
    num_chunks_retrieved: int


class RAGClient:
    """Client for RAG (Retrieval-Augmented Generation) operations.

    Provides a complete pipeline for:
    1. Indexing documents into vector database
    2. Querying with semantic search
    3. Generating responses with LLM using retrieved context
    """

    def __init__(
        self,
        embedding_provider: Optional[EmbeddingProvider] = None,
        llm_client: Optional[LLMClient] = None,
        persist_directory: Optional[str] = None
    ):
        """Initialize RAG client.

        Args:
            embedding_provider: Provider for embeddings (default: OpenAI)
            llm_client: LLM client for generation (default: creates new client)
            persist_directory: Directory for vector database (default: ./chroma_db)
        """
        # Initialize embedding provider
        if embedding_provider is None:
            embedding_provider = create_embedding_provider("openai")
        self.embedding_provider = embedding_provider

        # Initialize vector database
        self.vectordb = VectorDBClient(
            embedding_provider=embedding_provider,
            persist_directory=persist_directory
        )

        # Initialize LLM client
        if llm_client is None:
            llm_client = LLMClient()
        self.llm_client = llm_client

    async def index_file(
        self,
        file_path: str,
        collection_name: str,
        chunk_size: int = 1000,
        overlap: int = 200
    ) -> IndexingResult:
        """Index a single file into the vector database.

        Args:
            file_path: Path to file to index
            collection_name: Name of collection to store chunks
            chunk_size: Size of each chunk in characters
            overlap: Overlap between chunks in characters

        Returns:
            IndexingResult with statistics
        """
        # Read file
        path = Path(file_path)
        if not path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        content = path.read_text(encoding="utf-8")

        # Chunk content
        chunks = chunk_content(
            content,
            chunk_size=chunk_size,
            overlap=overlap,
            preserve_boundaries=True
        )

        # Create metadata for each chunk
        metadatas = [
            {
                "file": str(path.absolute()),
                "filename": path.name,
                "chunk_idx": i,
                "total_chunks": len(chunks)
            }
            for i in range(len(chunks))
        ]

        # Add to vector database (this generates embeddings automatically)
        # Run in thread pool since ChromaDB is sync
        await asyncio.to_thread(
            self.vectordb.add_documents,
            collection_name=collection_name,
            documents=chunks,
            metadatas=metadatas
        )

        # Get embedding stats (estimate)
        # Note: actual costs tracked by embedding provider during add_documents
        chunk_metadata = get_chunk_metadata(chunks)

        return IndexingResult(
            collection_name=collection_name,
            num_chunks=len(chunks),
            num_files=1,
            total_tokens=chunk_metadata["total_chars"] // 4,  # Rough estimate
            embedding_cost=0.0  # Would need to track from vectordb
        )

    async def index_directory(
        self,
        directory_path: str,
        collection_name: str,
        file_patterns: Optional[List[str]] = None,
        chunk_size: int = 1000,
        overlap: int = 200
    ) -> IndexingResult:
        """Index all files in a directory.

        Args:
            directory_path: Path to directory
            collection_name: Name of collection to store chunks
            file_patterns: List of file patterns (e.g., ["*.txt", "*.md"])
            chunk_size: Size of each chunk in characters
            overlap: Overlap between chunks in characters

        Returns:
            IndexingResult with statistics
        """
        dir_path = Path(directory_path)
        if not dir_path.exists() or not dir_path.is_dir():
            raise ValueError(f"Invalid directory: {directory_path}")

        # Default patterns
        if file_patterns is None:
            file_patterns = ["*.txt", "*.md", "*.py", "*.js", "*.java", "*.cpp"]

        # Find all matching files
        files = []
        for pattern in file_patterns:
            files.extend(dir_path.rglob(pattern))

        if not files:
            raise LLMAbstractionError(
                f"No files found matching patterns {file_patterns} in {directory_path}"
            )

        # Index each file
        total_chunks = 0
        total_tokens = 0

        for file_path in files:
            try:
                result = await self.index_file(
                    file_path=str(file_path),
                    collection_name=collection_name,
                    chunk_size=chunk_size,
                    overlap=overlap
                )
                total_chunks += result.num_chunks
                total_tokens += result.total_tokens
            except Exception as e:
                # Log error but continue with other files
                print(f"Warning: Failed to index {file_path}: {str(e)}")

        return IndexingResult(
            collection_name=collection_name,
            num_chunks=total_chunks,
            num_files=len(files),
            total_tokens=total_tokens,
            embedding_cost=0.0
        )

    async def query(
        self,
        collection_name: str,
        query: str,
        provider: str = "openai",
        model: str = "gpt-4o-mini",
        n_results: int = 5,
        include_sources: bool = True
    ) -> RAGResponse:
        """Query the RAG system.

        Args:
            collection_name: Collection to query
            query: User query
            provider: LLM provider for generation
            model: LLM model for generation
            n_results: Number of chunks to retrieve
            include_sources: Whether to include source citations

        Returns:
            RAGResponse with generated content and sources
        """
        # Retrieve relevant chunks (run in thread pool since ChromaDB is sync)
        search_results = await asyncio.to_thread(
            self.vectordb.query,
            collection_name=collection_name,
            query_text=query,
            n_results=n_results
        )

        if not search_results:
            raise LLMAbstractionError(
                f"No results found in collection '{collection_name}'. "
                "Ensure the collection exists and has documents."
            )

        # Build context from retrieved chunks
        context_parts = []
        sources = []

        for i, result in enumerate(search_results):
            context_parts.append(f"[Source {i+1}]\n{result.document}")
            sources.append({
                "index": i + 1,
                "file": result.metadata.get("filename", "unknown"),
                "chunk_idx": result.metadata.get("chunk_idx", 0),
                "similarity": 1.0 - result.distance  # Convert distance to similarity
            })

        context = "\n\n".join(context_parts)

        # Build prompt with context
        prompt = f"""Answer the following question using ONLY the information provided in the sources below.
If the answer cannot be found in the sources, say "I cannot answer this based on the provided sources."

Sources:
{context}

Question: {query}

Answer:"""

        # Generate response
        messages = [Message(role="user", content=prompt)]
        request = ChatRequest(
            model=model,
            messages=messages,
            temperature=0.3  # Lower temperature for more factual responses
        )

        # Create client with specified provider for this request
        client = LLMClient(provider=provider)
        response = await client.chat_completion(request=request)

        # Calculate total cost (embedding + generation)
        total_cost = response.usage.cost_usd if response.usage else 0.0

        return RAGResponse(
            content=response.content,
            sources=sources if include_sources else [],
            model=response.model,
            total_cost=total_cost,
            num_chunks_retrieved=len(search_results)
        )

    def list_collections(self) -> List[str]:
        """List all available collections.

        Returns:
            List of collection names
        """
        return self.vectordb.list_collections()

    def delete_collection(self, collection_name: str) -> None:
        """Delete a collection.

        Args:
            collection_name: Collection to delete
        """
        self.vectordb.delete_collection(collection_name)

    def get_collection_stats(self, collection_name: str) -> Dict[str, Any]:
        """Get statistics about a collection.

        Args:
            collection_name: Collection name

        Returns:
            Dictionary with collection statistics
        """
        count = self.vectordb.get_collection_count(collection_name)

        # Get sample documents to extract file info
        sample_docs = self.vectordb.get_documents(
            collection_name=collection_name,
            limit=100
        )

        # Extract unique files
        files = set()
        for doc in sample_docs:
            if "filename" in doc["metadata"]:
                files.add(doc["metadata"]["filename"])

        return {
            "name": collection_name,
            "num_chunks": count,
            "num_files": len(files),
            "sample_files": list(files)[:10]  # Show up to 10 files
        }

    def retrieve_only(
        self,
        collection_name: str,
        query: str,
        n_results: int = 5
    ) -> List[SearchResult]:
        """Retrieve relevant chunks without generating a response.

        Useful for testing retrieval quality.

        Args:
            collection_name: Collection to query
            query: Query text
            n_results: Number of results to return

        Returns:
            List of SearchResult objects
        """
        return self.vectordb.query(
            collection_name=collection_name,
            query_text=query,
            n_results=n_results
        )
```
stratifyai/retry.py
ADDED
@@ -0,0 +1,185 @@

```python
"""Retry logic with exponential backoff and fallback support."""

import asyncio
import logging
from typing import Callable, List, Optional, Type, Tuple
from functools import wraps
from dataclasses import dataclass

from .exceptions import (
    RateLimitError,
    ProviderAPIError,
    MaxRetriesExceededError,
)


@dataclass
class RetryConfig:
    """Configuration for retry behavior."""

    max_retries: int = 3
    initial_delay: float = 1.0  # seconds
    max_delay: float = 60.0  # seconds
    exponential_base: float = 2.0
    jitter: bool = True
    retry_on_exceptions: Tuple[Type[Exception], ...] = (
        RateLimitError,
        ProviderAPIError,
    )


def with_retry(
    config: Optional[RetryConfig] = None,
    fallback_models: Optional[List[str]] = None,
    fallback_provider: Optional[str] = None,
):
    """
    Decorator to add async retry logic with exponential backoff.

    Args:
        config: Retry configuration
        fallback_models: List of fallback models to try if primary fails
        fallback_provider: Fallback provider to use if primary fails

    Usage:
        @with_retry(config=RetryConfig(max_retries=5))
        async def my_llm_call():
            ...
    """
    if config is None:
        config = RetryConfig()

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            last_exception = None

            for attempt in range(config.max_retries + 1):
                try:
                    return await func(*args, **kwargs)
                except config.retry_on_exceptions as e:
                    last_exception = e

                    if attempt == config.max_retries:
                        # Try fallbacks if configured
                        if fallback_models or fallback_provider:
                            logging.info(f"Attempting fallback after {attempt + 1} retries")
                            return await _try_fallback_async(
                                func,
                                args,
                                kwargs,
                                fallback_models,
                                fallback_provider,
                                last_exception,
                            )
                        raise MaxRetriesExceededError(
                            f"Max retries ({config.max_retries}) exceeded. Last error: {str(e)}"
                        )

                    # Calculate delay with exponential backoff
                    delay = min(
                        config.initial_delay * (config.exponential_base ** attempt),
                        config.max_delay
                    )

                    # Add jitter if enabled
                    if config.jitter:
                        import random
                        delay *= (0.5 + random.random())

                    logging.warning(
                        f"Retry attempt {attempt + 1}/{config.max_retries} "
                        f"after {delay:.2f}s delay. Error: {str(e)}"
                    )
                    await asyncio.sleep(delay)

            # Should never reach here, but just in case
            raise last_exception

        return async_wrapper
    return decorator


async def _try_fallback_async(
    func: Callable,
    args: tuple,
    kwargs: dict,
    fallback_models: Optional[List[str]],
    fallback_provider: Optional[str],
    original_error: Exception,
):
    """
    Try fallback models or providers asynchronously.

    Args:
        func: Original async function
        args: Function args
        kwargs: Function kwargs
        fallback_models: List of fallback model names
        fallback_provider: Fallback provider name
        original_error: Original exception that triggered fallback

    Returns:
        Result from successful fallback

    Raises:
        MaxRetriesExceededError: If all fallbacks fail
    """
    # Try fallback models first
    if fallback_models:
        for model in fallback_models:
            try:
                logging.info(f"Trying fallback model: {model}")
                # Update model in kwargs if present
                if 'request' in kwargs and hasattr(kwargs['request'], 'model'):
                    kwargs['request'].model = model
                elif 'model' in kwargs:
                    kwargs['model'] = model
                return await func(*args, **kwargs)
            except Exception as e:
                logging.warning(f"Fallback model {model} failed: {str(e)}")
                continue

    # Try fallback provider
    if fallback_provider:
        try:
            logging.info(f"Trying fallback provider: {fallback_provider}")
            if 'provider' in kwargs:
                kwargs['provider'] = fallback_provider
            return await func(*args, **kwargs)
        except Exception as e:
            logging.warning(f"Fallback provider {fallback_provider} failed: {str(e)}")

    # All fallbacks failed
    raise MaxRetriesExceededError(
        f"All retries and fallbacks failed. Original error: {str(original_error)}"
    )


def exponential_backoff(
    attempt: int,
    initial_delay: float = 1.0,
    exponential_base: float = 2.0,
    max_delay: float = 60.0,
    jitter: bool = True,
) -> float:
    """
    Calculate exponential backoff delay.

    Args:
        attempt: Current attempt number (0-indexed)
        initial_delay: Initial delay in seconds
        exponential_base: Base for exponential growth
        max_delay: Maximum delay cap
        jitter: Add random jitter to delay

    Returns:
        Delay in seconds
    """
    delay = min(initial_delay * (exponential_base ** attempt), max_delay)

    if jitter:
        import random
        delay *= (0.5 + random.random())

    return delay
```