ragfallback 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,24 @@
1
+ """
2
+ ragfallback - RAG Fallback Strategies Library
3
+
4
+ A production-ready Python library that adds intelligent fallback strategies
5
+ to RAG systems, preventing silent failures and improving answer quality.
6
+ """
7
+
8
+ __version__ = "0.1.0"
9
+ __author__ = "Irfan Ali"
10
+
11
+ from ragfallback.core.adaptive_retriever import AdaptiveRAGRetriever, QueryResult
12
+ from ragfallback.strategies.query_variations import QueryVariationsStrategy
13
+ from ragfallback.tracking.cost_tracker import CostTracker, ModelPricing
14
+ from ragfallback.tracking.metrics import MetricsCollector
15
+
16
+ __all__ = [
17
+ "AdaptiveRAGRetriever",
18
+ "QueryResult",
19
+ "QueryVariationsStrategy",
20
+ "CostTracker",
21
+ "ModelPricing",
22
+ "MetricsCollector",
23
+ ]
24
+
@@ -0,0 +1,6 @@
1
+ """Core retriever components."""
2
+
3
+ from ragfallback.core.adaptive_retriever import AdaptiveRAGRetriever, QueryResult
4
+
5
+ __all__ = ["AdaptiveRAGRetriever", "QueryResult"]
6
+
@@ -0,0 +1,387 @@
1
+ """Adaptive RAG Retriever with fallback strategies."""
2
+
3
+ from typing import List, Optional, Dict, Any, Tuple
4
+ from dataclasses import dataclass
5
+ import logging
6
+ import time
7
+ import json
8
+ import re
9
+
10
+ from langchain_core.vectorstores import VectorStore
11
+ from langchain_core.language_models import BaseLanguageModel
12
+ from langchain_core.embeddings import Embeddings
13
+ from langchain_core.messages import SystemMessage, HumanMessage
14
+
15
+ from ragfallback.strategies.base import FallbackStrategy
16
+ from ragfallback.strategies.query_variations import QueryVariationsStrategy
17
+ from ragfallback.tracking.cost_tracker import CostTracker
18
+ from ragfallback.tracking.metrics import MetricsCollector
19
+ from ragfallback.utils.confidence_scorer import ConfidenceScorer
20
+
21
+
22
@dataclass
class QueryResult:
    """Outcome of a single RAG query, including fallback metadata.

    Attributes mirror what the retriever records per query: the final
    answer text, its cited source, a confidence score in [0, 1], how many
    attempts were consumed, the accumulated cost in dollars, and the
    optional per-attempt trace.
    """

    answer: str
    source: str
    confidence: float
    attempts: int
    cost: float
    intermediate_steps: Optional[List[Dict]] = None

    def __repr__(self):
        # Show only the first 50 characters of the answer so log lines
        # stay readable; confidence is rendered as a percentage.
        preview = self.answer[:50]
        return (
            "QueryResult(answer='{}...', "
            "confidence={:.2%}, attempts={}, "
            "cost=${:.4f})".format(preview, self.confidence, self.attempts, self.cost)
        )
39
+
40
+
41
class AdaptiveRAGRetriever:
    """
    Adaptive RAG Retriever with intelligent fallback strategies.

    This retriever attempts to answer queries using multiple strategies,
    falling back to alternative approaches when initial attempts fail
    or yield low-confidence results.
    """

    # System prompt for answer generation; used when no custom template is
    # supplied. Instructs the model to emit {"answer": ..., "source": ...}
    # JSON, which _parse_answer later extracts.
    DEFAULT_ANSWER_PROMPT = """You are a helpful assistant that answers questions based on provided documents.

Answer the question based on the documents provided. If the answer is not in the documents, respond with "Not found".

Return your answer in JSON format: {"answer": "...", "source": "..."}"""

    def __init__(
        self,
        vector_store: VectorStore,
        llm: BaseLanguageModel,
        embedding_model: Embeddings,
        fallback_strategy: str = "query_variations",
        fallback_strategies: Optional[List[FallbackStrategy]] = None,
        max_attempts: int = 3,
        min_confidence: float = 0.7,
        cost_tracker: Optional[CostTracker] = None,
        metrics_collector: Optional[MetricsCollector] = None,
        enable_logging: bool = True,
        answer_prompt_template: Optional[str] = None
    ):
        """
        Initialize AdaptiveRAGRetriever.

        Args:
            vector_store: Vector store instance (Chroma, Qdrant, etc.)
            llm: Language model for query generation and answer synthesis
            embedding_model: Embedding model for semantic search
            fallback_strategy: Name of default fallback strategy
            fallback_strategies: List of custom fallback strategies
            max_attempts: Maximum number of query attempts
            min_confidence: Minimum confidence threshold for success
            cost_tracker: Optional cost tracking instance
            metrics_collector: Optional metrics collection instance
            enable_logging: Enable detailed logging
            answer_prompt_template: Custom prompt template for answer generation
        """
        self.vector_store = vector_store
        self.llm = llm
        # NOTE(review): embedding_model is stored but not referenced by any
        # method in this class — presumably for future strategies; confirm.
        self.embedding_model = embedding_model
        self.max_attempts = max_attempts
        self.min_confidence = min_confidence
        # Fall back to fresh tracker/collector instances when none injected.
        self.cost_tracker = cost_tracker or CostTracker()
        self.metrics_collector = metrics_collector or MetricsCollector()
        self.answer_prompt_template = answer_prompt_template or self.DEFAULT_ANSWER_PROMPT
        # logger is None when logging is disabled; every log call below is
        # guarded by `if self.logger`.
        self.logger = logging.getLogger(__name__) if enable_logging else None

        # Setup fallback strategies
        if fallback_strategies:
            self.strategies = fallback_strategies
        else:
            if fallback_strategy == "query_variations":
                self.strategies = [QueryVariationsStrategy()]
            else:
                raise ValueError(f"Unknown fallback strategy: {fallback_strategy}")

    def query_with_fallback(
        self,
        question: str,
        context: Optional[Dict[str, Any]] = None,
        return_intermediate_steps: bool = False,
        enforce_budget: bool = False
    ) -> QueryResult:
        """
        Query with automatic fallback strategies.

        Args:
            question: The question to answer
            context: Optional context dictionary (e.g., {"company": "Acme"})
            return_intermediate_steps: Return all intermediate attempts
            enforce_budget: Stop if budget exceeded

        Returns:
            QueryResult with answer, source, confidence, and metadata
        """
        context = context or {}
        intermediate_steps = []
        total_cost = 0.0
        start_time = time.time()

        # Try each strategy in order
        for strategy_idx, strategy in enumerate(self.strategies):
            if strategy_idx >= self.max_attempts:
                break

            # Check budget
            if enforce_budget and self.cost_tracker.budget_exceeded():
                if self.logger:
                    self.logger.warning("Budget exceeded, stopping fallback attempts")
                break

            # Generate queries using strategy
            queries = strategy.generate_queries(
                original_query=question,
                context=context,
                attempt=strategy_idx + 1,
                llm=self.llm
            )

            # Try each query variation
            for query_idx, query in enumerate(queries):
                # NOTE(review): this numbering assumes every strategy yields
                # the same number of queries as the current one; with
                # variable-length query lists the attempt count can drift.
                attempt_num = strategy_idx * len(queries) + query_idx + 1

                if attempt_num > self.max_attempts:
                    break

                if self.logger:
                    self.logger.info(f"Attempt {attempt_num}/{self.max_attempts}: {query[:100]}...")

                # Retrieve documents
                docs = self._retrieve_documents(query, context)

                if not docs:
                    if self.logger:
                        self.logger.warning(f"No documents found for query: {query}")
                    # Record a zero-confidence step so failed retrievals are
                    # visible in the intermediate trace.
                    intermediate_steps.append({
                        "attempt": attempt_num,
                        "query": query,
                        "documents": 0,
                        "confidence": 0.0,
                        "cost": 0.0
                    })
                    continue

                # Generate answer
                answer, source, confidence, cost = self._generate_answer(
                    question=question,
                    query=query,
                    documents=docs,
                    context=context
                )

                total_cost += cost
                latency_ms = (time.time() - start_time) * 1000

                # Track attempt
                step_data = {
                    "attempt": attempt_num,
                    "query": query,
                    "documents": len(docs),
                    "answer": answer,
                    "source": source,
                    "confidence": confidence,
                    "cost": cost
                }
                intermediate_steps.append(step_data)

                # Check if we have a good answer: must clear the confidence
                # threshold AND not be a known "no answer" sentinel string.
                if confidence >= self.min_confidence and answer.lower() not in ["x", "not found", "n/a", "unknown"]:
                    if self.logger:
                        self.logger.info(f"✅ Success on attempt {attempt_num} (confidence: {confidence:.2%})")

                    # Update metrics
                    self.metrics_collector.record_success(
                        attempts=attempt_num,
                        confidence=confidence,
                        cost=total_cost,
                        latency_ms=latency_ms,
                        strategy_used=strategy.get_name()
                    )

                    return QueryResult(
                        answer=answer,
                        source=source,
                        confidence=confidence,
                        attempts=attempt_num,
                        cost=total_cost,
                        intermediate_steps=intermediate_steps if return_intermediate_steps else None
                    )

        # All attempts failed
        latency_ms = (time.time() - start_time) * 1000

        if self.logger:
            self.logger.warning(f"⚠️ All {len(intermediate_steps)} attempts failed")

        # Return best attempt or default: pick the recorded step with the
        # highest confidence rather than the last one.
        if intermediate_steps:
            best_attempt = max(intermediate_steps, key=lambda x: x.get("confidence", 0.0))
            best_answer = best_attempt.get("answer", "No answer found")
            best_source = best_attempt.get("source", "")
            best_confidence = best_attempt.get("confidence", 0.0)
        else:
            best_answer = "No answer found"
            best_source = ""
            best_confidence = 0.0

        # NOTE(review): failure is attributed to the FIRST configured
        # strategy, not necessarily the last one attempted.
        self.metrics_collector.record_failure(
            attempts=len(intermediate_steps),
            cost=total_cost,
            latency_ms=latency_ms,
            strategy_used=self.strategies[0].get_name() if self.strategies else "unknown"
        )

        return QueryResult(
            answer=best_answer,
            source=best_source,
            confidence=best_confidence,
            attempts=len(intermediate_steps) or 1,
            cost=total_cost,
            intermediate_steps=intermediate_steps if return_intermediate_steps else None
        )

    def _retrieve_documents(
        self,
        query: str,
        context: Dict[str, Any]
    ) -> List:
        """Retrieve documents from vector store.

        Returns an empty list on any retrieval error (best-effort: the
        caller treats "no docs" as a failed attempt and keeps going).
        """
        try:
            # Apply context filters if supported
            search_kwargs = self._build_search_kwargs(context)

            retriever = self.vector_store.as_retriever(
                search_kwargs=search_kwargs
            )

            # NOTE(review): get_relevant_documents is deprecated in newer
            # langchain-core in favor of retriever.invoke(query) — confirm
            # the pinned langchain version before migrating.
            docs = retriever.get_relevant_documents(query)
            return docs
        except Exception as e:
            if self.logger:
                self.logger.error(f"Error retrieving documents: {e}")
            return []

    def _generate_answer(
        self,
        question: str,
        query: str,
        documents: List,
        context: Dict[str, Any]
    ) -> Tuple[str, str, float, float]:
        """
        Generate answer from documents.

        Returns:
            Tuple of (answer, source, confidence, cost)
        """
        # Format documents
        docs_text = self._format_documents(documents)

        # Create prompt
        prompt = self._build_answer_prompt(question, docs_text, context)

        # Generate answer with cost tracking
        with self.cost_tracker.track(operation="answer_generation"):
            messages = [
                SystemMessage(content=self.answer_prompt_template),
                HumanMessage(content=prompt)
            ]
            response = self.llm.invoke(messages)
            # Chat models return a message object with .content; plain LLMs
            # may return a bare string.
            answer_text = response.content if hasattr(response, 'content') else str(response)

            # Try to extract token usage if available
            if hasattr(response, 'response_metadata'):
                metadata = response.response_metadata
                if 'token_usage' in metadata:
                    usage = metadata['token_usage']
                    # NOTE(review): defaults to 'gpt-4' pricing when the llm
                    # object exposes no model_name — verify this is intended.
                    self.cost_tracker.record_tokens(
                        input_tokens=usage.get('prompt_tokens', 0),
                        output_tokens=usage.get('completion_tokens', 0),
                        model=getattr(self.llm, 'model_name', 'gpt-4')
                    )

        # Extract answer and source
        answer, source = self._parse_answer(answer_text)

        # Calculate confidence (a fresh scorer is built per call)
        scorer = ConfidenceScorer(llm=self.llm)
        confidence = scorer.score(
            question=question,
            answer=answer,
            documents=documents,
            context=context
        )

        # Get cost
        cost = self.cost_tracker.get_last_cost()

        return answer, source, confidence, cost

    def _format_documents(self, documents: List) -> str:
        """Format documents for prompt.

        Each document is rendered as "Document N (from SOURCE): ..." using
        page_content/metadata when present, falling back to str(doc).
        """
        formatted = []
        for i, doc in enumerate(documents, 1):
            content = doc.page_content if hasattr(doc, 'page_content') else str(doc)
            source = doc.metadata.get('source', 'Unknown') if hasattr(doc, 'metadata') else 'Unknown'
            formatted.append(f"Document {i} (from {source}):\n{content}\n")
        return "\n".join(formatted)

    def _build_answer_prompt(
        self,
        question: str,
        docs_text: str,
        context: Dict[str, Any]
    ) -> str:
        """Build answer generation prompt.

        The optional context dict is serialized as pretty-printed JSON and
        embedded between the question and the documents.
        """
        context_str = ""
        if context:
            context_str = f"\n\nContext: {json.dumps(context, indent=2)}\n"

        return f"""Based on the following documents, answer the question.

Question: {question}
{context_str}
Documents:
{docs_text}

Provide a clear, concise answer. If the answer is not in the documents, respond with "Not found".
Return your answer in JSON format: {{"answer": "...", "source": "..."}}"""

    def _parse_answer(self, answer_text: str) -> Tuple[str, str]:
        """Parse answer from LLM response.

        Attempts to pull a {"answer": ..., "source": ...} object out of the
        raw text; falls back to returning the raw text with empty source.
        """
        # Try to extract JSON. NOTE(review): the pattern stops at the first
        # '}', so nested JSON objects will not parse.
        json_match = re.search(r'\{[^}]+\}', answer_text, re.DOTALL)
        if json_match:
            try:
                parsed = json.loads(json_match.group())
                return parsed.get("answer", answer_text), parsed.get("source", "")
            except json.JSONDecodeError:
                pass

        # Fallback: return as-is
        return answer_text, ""

    def _build_search_kwargs(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Build search kwargs with context filters."""
        kwargs = {"k": 5}  # Default top-k

        # Add context-based filters if vector store supports it
        if hasattr(self.vector_store, 'filter'):
            # Only a whitelisted set of context keys becomes a filter.
            filters = {}
            for key, value in context.items():
                if key in ["company_key", "unique_id", "filter_id"]:
                    filters[key] = value
            if filters:
                kwargs["filter"] = filters

        return kwargs
387
+
ragfallback/py.typed ADDED
@@ -0,0 +1,2 @@
1
+ # Marker file for PEP 561
2
+
@@ -0,0 +1,7 @@
1
+ """Fallback strategy implementations."""
2
+
3
+ from ragfallback.strategies.base import FallbackStrategy
4
+ from ragfallback.strategies.query_variations import QueryVariationsStrategy
5
+
6
+ __all__ = ["FallbackStrategy", "QueryVariationsStrategy"]
7
+
@@ -0,0 +1,37 @@
1
+ """Base class for fallback strategies."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import List, Dict, Any, Optional
5
+
6
+ from langchain_core.language_models import BaseLanguageModel
7
+
8
+
9
class FallbackStrategy(ABC):
    """Abstract interface every fallback strategy must implement.

    A strategy turns one failed (or first-time) query into a list of
    candidate query strings for the retriever to try.
    """

    @abstractmethod
    def generate_queries(
        self,
        original_query: str,
        context: Dict[str, Any],
        attempt: int,
        llm: BaseLanguageModel
    ) -> List[str]:
        """
        Generate query variations for fallback.

        Args:
            original_query: The original query string
            context: Context dictionary (e.g., {"company": "Acme"})
            attempt: Current attempt number (1-indexed)
            llm: Language model for query generation

        Returns:
            List of query strings to try
        """
        ...

    def get_name(self) -> str:
        """Return a human-readable strategy name (the concrete class name)."""
        return type(self).__name__
37
+
@@ -0,0 +1,146 @@
1
+ """Query Variations Strategy - LLM-based query rewriting."""
2
+
3
import json
import logging
import re
from typing import List, Dict, Any, Optional

from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import HumanMessage

from ragfallback.strategies.base import FallbackStrategy
11
+
12
+
13
class QueryVariationsStrategy(FallbackStrategy):
    """
    Generate query variations using LLM.

    This strategy uses an LLM to generate alternative formulations
    of the original query, increasing the chances of finding relevant documents.
    """

    DEFAULT_PROMPT_TEMPLATE = """Generate {num_variations} alternative ways to ask this question that might find the answer in documentation.

Original question: "{query}"
{context}

Requirements:
- Use different terminology or phrasing
- Be more specific or more general as needed
- Focus on key concepts from the original question

Return ONLY a JSON array of strings: ["variation 1", "variation 2", ...]
Do not include any explanation or markdown formatting."""

    def __init__(
        self,
        num_variations: int = 2,
        include_original: bool = True,
        variation_prompt_template: Optional[str] = None
    ):
        """
        Initialize QueryVariationsStrategy.

        Args:
            num_variations: Number of variations to generate
            include_original: Include original query in results
            variation_prompt_template: Custom prompt template
        """
        self.num_variations = num_variations
        self.include_original = include_original
        self.prompt_template = variation_prompt_template or self.DEFAULT_PROMPT_TEMPLATE
        self.logger = logging.getLogger(__name__)

    def generate_queries(
        self,
        original_query: str,
        context: Dict[str, Any],
        attempt: int,
        llm: BaseLanguageModel
    ) -> List[str]:
        """
        Generate query variations.

        Args:
            original_query: Original query string
            context: Context dictionary
            attempt: Current attempt number
            llm: Language model for generation

        Returns:
            List of query strings
        """
        queries = []

        # Include original query first, but only on the first attempt so
        # later attempts don't retry a query that already failed.
        if self.include_original and attempt == 1:
            queries.append(original_query)

        # Generate LLM-rewritten variations
        if self.num_variations > 0:
            variations = self._generate_variations(
                original_query=original_query,
                context=context,
                num_variations=self.num_variations,
                llm=llm
            )
            queries.extend(variations)

        self.logger.info(f"Generated {len(queries)} query variations")
        return queries

    def _generate_variations(
        self,
        original_query: str,
        context: Dict[str, Any],
        num_variations: int,
        llm: BaseLanguageModel
    ) -> List[str]:
        """Generate query variations using LLM.

        Returns an empty list on LLM errors; falls back to text extraction
        when the response is not valid JSON.
        """
        # Build context string (pretty-printed JSON, matching the retriever)
        context_str = ""
        if context:
            context_str = f"\n\nContext: {json.dumps(context, indent=2)}"

        # Build prompt
        prompt = self.prompt_template.format(
            query=original_query,
            num_variations=num_variations,
            context=context_str
        )

        # Call the LLM in its own try-block: the original code caught
        # JSONDecodeError on the same try that covered llm.invoke(), which
        # could reference response_text before assignment if the client
        # itself raised a decode error.
        try:
            response = llm.invoke([HumanMessage(content=prompt)])
        except Exception as e:
            self.logger.error(f"Error generating query variations: {e}")
            return []

        response_text = response.content if hasattr(response, 'content') else str(response)

        # Parse JSON array; tolerate code fences the model may add despite
        # the prompt's instructions.
        try:
            variations = json.loads(self._strip_code_fences(response_text))
        except json.JSONDecodeError as e:
            self.logger.error(f"Error parsing JSON from LLM response: {e}")
            # Try to extract array from markdown or text
            return self._extract_variations_from_text(response_text)

        if isinstance(variations, list):
            return variations[:num_variations]
        self.logger.warning(f"Expected list, got {type(variations)}")
        return []

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Strip a surrounding markdown code fence (``` or ```json) if present."""
        cleaned = text.strip()
        if cleaned.startswith("```"):
            cleaned = cleaned.strip("`").strip()
            # Drop an optional language tag left after removing the backticks.
            if cleaned.lower().startswith("json"):
                cleaned = cleaned[4:].strip()
        return cleaned

    def _extract_variations_from_text(self, text: str) -> List[str]:
        """Extract variations from text if JSON parsing fails.

        First tries to re-parse the first bracketed span as a JSON array;
        otherwise splits into lines and keeps plausible query strings.
        """
        # Try to find array-like patterns
        array_match = re.search(r'\[(.*?)\]', text, re.DOTALL)
        if array_match:
            try:
                variations = json.loads(f"[{array_match.group(1)}]")
                if isinstance(variations, list):
                    return variations
            except json.JSONDecodeError:
                # Was a bare `except:` — narrowed to the only error loads raises.
                pass

        # Fallback: split by lines and clean; length > 10 filters out noise
        # like stray brackets or "Here are:" fragments.
        lines = [line.strip().strip('"').strip("'") for line in text.split('\n')]
        return [line for line in lines if line and len(line) > 10][:self.num_variations]
146
+
@@ -0,0 +1,7 @@
1
+ """Cost tracking and metrics collection."""
2
+
3
+ from ragfallback.tracking.cost_tracker import CostTracker, ModelPricing
4
+ from ragfallback.tracking.metrics import MetricsCollector
5
+
6
+ __all__ = ["CostTracker", "ModelPricing", "MetricsCollector"]
7
+