ragfallback 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragfallback/__init__.py +24 -0
- ragfallback/core/__init__.py +6 -0
- ragfallback/core/adaptive_retriever.py +387 -0
- ragfallback/py.typed +2 -0
- ragfallback/strategies/__init__.py +7 -0
- ragfallback/strategies/base.py +37 -0
- ragfallback/strategies/query_variations.py +146 -0
- ragfallback/tracking/__init__.py +7 -0
- ragfallback/tracking/cost_tracker.py +128 -0
- ragfallback/tracking/metrics.py +107 -0
- ragfallback/utils/__init__.py +36 -0
- ragfallback/utils/confidence_scorer.py +138 -0
- ragfallback/utils/embedding_factory.py +119 -0
- ragfallback/utils/llm_factory.py +231 -0
- ragfallback/utils/vector_store_factory.py +163 -0
- ragfallback-0.1.0.dist-info/LICENSE +21 -0
- ragfallback-0.1.0.dist-info/METADATA +878 -0
- ragfallback-0.1.0.dist-info/RECORD +27 -0
- ragfallback-0.1.0.dist-info/WHEEL +5 -0
- ragfallback-0.1.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +2 -0
- tests/conftest.py +88 -0
- tests/test_confidence_scorer.py +79 -0
- tests/test_cost_tracker.py +55 -0
- tests/test_integration.py +108 -0
- tests/test_metrics.py +70 -0
- tests/test_query_variations.py +81 -0
ragfallback/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ragfallback - RAG Fallback Strategies Library
|
|
3
|
+
|
|
4
|
+
A production-ready Python library that adds intelligent fallback strategies
|
|
5
|
+
to RAG systems, preventing silent failures and improving answer quality.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
__version__ = "0.1.0"
|
|
9
|
+
__author__ = "Irfan Ali"
|
|
10
|
+
|
|
11
|
+
from ragfallback.core.adaptive_retriever import AdaptiveRAGRetriever, QueryResult
|
|
12
|
+
from ragfallback.strategies.query_variations import QueryVariationsStrategy
|
|
13
|
+
from ragfallback.tracking.cost_tracker import CostTracker, ModelPricing
|
|
14
|
+
from ragfallback.tracking.metrics import MetricsCollector
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
"AdaptiveRAGRetriever",
|
|
18
|
+
"QueryResult",
|
|
19
|
+
"QueryVariationsStrategy",
|
|
20
|
+
"CostTracker",
|
|
21
|
+
"ModelPricing",
|
|
22
|
+
"MetricsCollector",
|
|
23
|
+
]
|
|
24
|
+
|
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
"""Adaptive RAG Retriever with fallback strategies."""
|
|
2
|
+
|
|
3
|
+
from typing import List, Optional, Dict, Any, Tuple
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
import logging
|
|
6
|
+
import time
|
|
7
|
+
import json
|
|
8
|
+
import re
|
|
9
|
+
|
|
10
|
+
from langchain_core.vectorstores import VectorStore
|
|
11
|
+
from langchain_core.language_models import BaseLanguageModel
|
|
12
|
+
from langchain_core.embeddings import Embeddings
|
|
13
|
+
from langchain_core.messages import SystemMessage, HumanMessage
|
|
14
|
+
|
|
15
|
+
from ragfallback.strategies.base import FallbackStrategy
|
|
16
|
+
from ragfallback.strategies.query_variations import QueryVariationsStrategy
|
|
17
|
+
from ragfallback.tracking.cost_tracker import CostTracker
|
|
18
|
+
from ragfallback.tracking.metrics import MetricsCollector
|
|
19
|
+
from ragfallback.utils.confidence_scorer import ConfidenceScorer
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class QueryResult:
    """Outcome of a single RAG query, with retrieval and cost metadata."""

    answer: str                     # final answer text
    source: str                     # source reference for the answer ("" if unknown)
    confidence: float               # confidence score, rendered as a percentage in repr
    attempts: int                   # number of query attempts that were made
    cost: float                     # accumulated LLM cost in dollars
    intermediate_steps: Optional[List[Dict]] = None  # per-attempt details, when requested

    def __repr__(self):
        # Truncate the answer preview to 50 characters; assemble the field
        # summaries separately, then join them into the canonical form.
        preview = self.answer[:50]
        parts = (
            f"answer='{preview}...'",
            f"confidence={self.confidence:.2%}",
            f"attempts={self.attempts}",
            f"cost=${self.cost:.4f}",
        )
        return f"QueryResult({', '.join(parts)})"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class AdaptiveRAGRetriever:
    """
    Adaptive RAG Retriever with intelligent fallback strategies.

    This retriever attempts to answer queries using multiple strategies,
    falling back to alternative approaches when initial attempts fail
    or yield low-confidence results.
    """

    # System prompt used for every answer-generation call. The trailing JSON
    # instruction pairs with _parse_answer(), which looks for {"answer", "source"}.
    DEFAULT_ANSWER_PROMPT = """You are a helpful assistant that answers questions based on provided documents.

Answer the question based on the documents provided. If the answer is not in the documents, respond with "Not found".

Return your answer in JSON format: {"answer": "...", "source": "..."}"""

    def __init__(
        self,
        vector_store: VectorStore,
        llm: BaseLanguageModel,
        embedding_model: Embeddings,
        fallback_strategy: str = "query_variations",
        fallback_strategies: Optional[List[FallbackStrategy]] = None,
        max_attempts: int = 3,
        min_confidence: float = 0.7,
        cost_tracker: Optional[CostTracker] = None,
        metrics_collector: Optional[MetricsCollector] = None,
        enable_logging: bool = True,
        answer_prompt_template: Optional[str] = None
    ):
        """
        Initialize AdaptiveRAGRetriever.

        Args:
            vector_store: Vector store instance (Chroma, Qdrant, etc.)
            llm: Language model for query generation and answer synthesis
            embedding_model: Embedding model for semantic search
            fallback_strategy: Name of default fallback strategy; only
                "query_variations" is recognized — anything else raises
                ValueError (ignored when ``fallback_strategies`` is given)
            fallback_strategies: List of custom fallback strategies; when
                provided, overrides ``fallback_strategy`` entirely
            max_attempts: Maximum number of query attempts
            min_confidence: Minimum confidence threshold for success
            cost_tracker: Optional cost tracking instance
            metrics_collector: Optional metrics collection instance
            enable_logging: Enable detailed logging
            answer_prompt_template: Custom prompt template for answer generation

        Raises:
            ValueError: If no ``fallback_strategies`` are given and
                ``fallback_strategy`` is not a recognized name.
        """
        self.vector_store = vector_store
        self.llm = llm
        # NOTE(review): embedding_model is stored but never used within this
        # class — retrieval goes through vector_store.as_retriever() below.
        # Presumably kept for future strategies; confirm before removing.
        self.embedding_model = embedding_model
        self.max_attempts = max_attempts
        self.min_confidence = min_confidence
        self.cost_tracker = cost_tracker or CostTracker()
        self.metrics_collector = metrics_collector or MetricsCollector()
        self.answer_prompt_template = answer_prompt_template or self.DEFAULT_ANSWER_PROMPT
        # A logger of None disables logging entirely; every call site guards on it.
        self.logger = logging.getLogger(__name__) if enable_logging else None

        # Setup fallback strategies
        if fallback_strategies:
            self.strategies = fallback_strategies
        else:
            if fallback_strategy == "query_variations":
                self.strategies = [QueryVariationsStrategy()]
            else:
                raise ValueError(f"Unknown fallback strategy: {fallback_strategy}")

    def query_with_fallback(
        self,
        question: str,
        context: Optional[Dict[str, Any]] = None,
        return_intermediate_steps: bool = False,
        enforce_budget: bool = False
    ) -> QueryResult:
        """
        Query with automatic fallback strategies.

        Iterates over configured strategies; each strategy produces one or
        more query variations, which are retrieved and answered in turn.
        The first answer that clears ``min_confidence`` (and is not a
        known non-answer) is returned immediately; otherwise the highest-
        confidence attempt is returned after all attempts are exhausted.

        Args:
            question: The question to answer
            context: Optional context dictionary (e.g., {"company": "Acme"})
            return_intermediate_steps: Return all intermediate attempts
            enforce_budget: Stop if budget exceeded

        Returns:
            QueryResult with answer, source, confidence, and metadata
        """
        context = context or {}
        intermediate_steps = []
        total_cost = 0.0
        start_time = time.time()

        # Try each strategy in order
        for strategy_idx, strategy in enumerate(self.strategies):
            # Cap the number of strategies tried by max_attempts as well.
            if strategy_idx >= self.max_attempts:
                break

            # Check budget
            if enforce_budget and self.cost_tracker.budget_exceeded():
                if self.logger:
                    self.logger.warning("Budget exceeded, stopping fallback attempts")
                break

            # Generate queries using strategy
            queries = strategy.generate_queries(
                original_query=question,
                context=context,
                attempt=strategy_idx + 1,
                llm=self.llm
            )

            # Try each query variation
            for query_idx, query in enumerate(queries):
                # NOTE(review): this counter assumes every strategy yields the
                # same number of queries; with heterogeneous strategies the
                # attempt numbers can skip or collide — confirm intended.
                attempt_num = strategy_idx * len(queries) + query_idx + 1

                if attempt_num > self.max_attempts:
                    break

                if self.logger:
                    self.logger.info(f"Attempt {attempt_num}/{self.max_attempts}: {query[:100]}...")

                # Retrieve documents
                docs = self._retrieve_documents(query, context)

                if not docs:
                    if self.logger:
                        self.logger.warning(f"No documents found for query: {query}")
                    # Record a zero-cost, zero-confidence step so the attempt
                    # is still visible in intermediate_steps.
                    intermediate_steps.append({
                        "attempt": attempt_num,
                        "query": query,
                        "documents": 0,
                        "confidence": 0.0,
                        "cost": 0.0
                    })
                    continue

                # Generate answer
                answer, source, confidence, cost = self._generate_answer(
                    question=question,
                    query=query,
                    documents=docs,
                    context=context
                )

                total_cost += cost
                # Latency is cumulative from the start of the whole call,
                # not per-attempt.
                latency_ms = (time.time() - start_time) * 1000

                # Track attempt
                step_data = {
                    "attempt": attempt_num,
                    "query": query,
                    "documents": len(docs),
                    "answer": answer,
                    "source": source,
                    "confidence": confidence,
                    "cost": cost
                }
                intermediate_steps.append(step_data)

                # Check if we have a good answer: confidence must clear the
                # threshold AND the answer must not be a known non-answer.
                # NOTE(review): the "x" sentinel presumably matches a model
                # failure token from some prompt variant — verify it is needed.
                if confidence >= self.min_confidence and answer.lower() not in ["x", "not found", "n/a", "unknown"]:
                    if self.logger:
                        self.logger.info(f"✅ Success on attempt {attempt_num} (confidence: {confidence:.2%})")

                    # Update metrics
                    self.metrics_collector.record_success(
                        attempts=attempt_num,
                        confidence=confidence,
                        cost=total_cost,
                        latency_ms=latency_ms,
                        strategy_used=strategy.get_name()
                    )

                    return QueryResult(
                        answer=answer,
                        source=source,
                        confidence=confidence,
                        attempts=attempt_num,
                        cost=total_cost,
                        intermediate_steps=intermediate_steps if return_intermediate_steps else None
                    )

        # All attempts failed
        latency_ms = (time.time() - start_time) * 1000

        if self.logger:
            self.logger.warning(f"⚠️ All {len(intermediate_steps)} attempts failed")

        # Return best attempt or default: pick the highest-confidence step,
        # even though it fell below the threshold.
        if intermediate_steps:
            best_attempt = max(intermediate_steps, key=lambda x: x.get("confidence", 0.0))
            best_answer = best_attempt.get("answer", "No answer found")
            best_source = best_attempt.get("source", "")
            best_confidence = best_attempt.get("confidence", 0.0)
        else:
            best_answer = "No answer found"
            best_source = ""
            best_confidence = 0.0

        self.metrics_collector.record_failure(
            attempts=len(intermediate_steps),
            cost=total_cost,
            latency_ms=latency_ms,
            strategy_used=self.strategies[0].get_name() if self.strategies else "unknown"
        )

        return QueryResult(
            answer=best_answer,
            source=best_source,
            confidence=best_confidence,
            attempts=len(intermediate_steps) or 1,
            cost=total_cost,
            intermediate_steps=intermediate_steps if return_intermediate_steps else None
        )

    def _retrieve_documents(
        self,
        query: str,
        context: Dict[str, Any]
    ) -> List:
        """Retrieve documents from vector store.

        Returns an empty list on any retrieval error (logged, never raised),
        which the caller treats as "no documents found".
        """
        try:
            # Apply context filters if supported
            search_kwargs = self._build_search_kwargs(context)

            retriever = self.vector_store.as_retriever(
                search_kwargs=search_kwargs
            )

            # NOTE(review): get_relevant_documents() is deprecated in newer
            # langchain-core in favor of retriever.invoke(); presumably the
            # package pins a compatible version — confirm.
            docs = retriever.get_relevant_documents(query)
            return docs
        except Exception as e:
            if self.logger:
                self.logger.error(f"Error retrieving documents: {e}")
            return []

    def _generate_answer(
        self,
        question: str,
        query: str,
        documents: List,
        context: Dict[str, Any]
    ) -> Tuple[str, str, float, float]:
        """
        Generate answer from documents.

        Invokes the LLM under the cost tracker, parses the (expected) JSON
        response, and scores the answer's confidence with ConfidenceScorer.

        Returns:
            Tuple of (answer, source, confidence, cost)
        """
        # Format documents
        docs_text = self._format_documents(documents)

        # Create prompt
        prompt = self._build_answer_prompt(question, docs_text, context)

        # Generate answer with cost tracking
        with self.cost_tracker.track(operation="answer_generation"):
            messages = [
                SystemMessage(content=self.answer_prompt_template),
                HumanMessage(content=prompt)
            ]
            response = self.llm.invoke(messages)
            # Chat models return a message object with .content; other model
            # types may return a plain string.
            answer_text = response.content if hasattr(response, 'content') else str(response)

            # Try to extract token usage if available
            if hasattr(response, 'response_metadata'):
                metadata = response.response_metadata
                if 'token_usage' in metadata:
                    usage = metadata['token_usage']
                    self.cost_tracker.record_tokens(
                        input_tokens=usage.get('prompt_tokens', 0),
                        output_tokens=usage.get('completion_tokens', 0),
                        # Falls back to 'gpt-4' pricing when the LLM object
                        # does not expose model_name.
                        model=getattr(self.llm, 'model_name', 'gpt-4')
                    )

        # Extract answer and source
        answer, source = self._parse_answer(answer_text)

        # Calculate confidence
        # NOTE(review): a new ConfidenceScorer is built on every call; if
        # construction is cheap this is fine, otherwise hoist to __init__.
        scorer = ConfidenceScorer(llm=self.llm)
        confidence = scorer.score(
            question=question,
            answer=answer,
            documents=documents,
            context=context
        )

        # Get cost — depends on record_tokens() having run inside track()
        # above; presumably get_last_cost() returns 0.0 otherwise — confirm.
        cost = self.cost_tracker.get_last_cost()

        return answer, source, confidence, cost

    def _format_documents(self, documents: List) -> str:
        """Format documents for prompt: one numbered section per document."""
        formatted = []
        for i, doc in enumerate(documents, 1):
            # Tolerate both LangChain Document objects and plain strings.
            content = doc.page_content if hasattr(doc, 'page_content') else str(doc)
            source = doc.metadata.get('source', 'Unknown') if hasattr(doc, 'metadata') else 'Unknown'
            formatted.append(f"Document {i} (from {source}):\n{content}\n")
        return "\n".join(formatted)

    def _build_answer_prompt(
        self,
        question: str,
        docs_text: str,
        context: Dict[str, Any]
    ) -> str:
        """Build answer generation prompt (question + context + documents)."""
        context_str = ""
        if context:
            context_str = f"\n\nContext: {json.dumps(context, indent=2)}\n"

        return f"""Based on the following documents, answer the question.

Question: {question}
{context_str}
Documents:
{docs_text}

Provide a clear, concise answer. If the answer is not in the documents, respond with "Not found".
Return your answer in JSON format: {{"answer": "...", "source": "..."}}"""

    def _parse_answer(self, answer_text: str) -> Tuple[str, str]:
        """Parse answer from LLM response.

        Returns (answer, source); falls back to the raw text with an empty
        source when no parseable JSON object is found.
        """
        # Try to extract JSON.
        # NOTE(review): [^}]+ stops at the first '}', so a nested JSON object
        # in the response would be truncated — acceptable only because the
        # prompt requests a flat {"answer", "source"} object.
        json_match = re.search(r'\{[^}]+\}', answer_text, re.DOTALL)
        if json_match:
            try:
                parsed = json.loads(json_match.group())
                return parsed.get("answer", answer_text), parsed.get("source", "")
            except json.JSONDecodeError:
                pass

        # Fallback: return as-is
        return answer_text, ""

    def _build_search_kwargs(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Build search kwargs with context filters."""
        kwargs = {"k": 5}  # Default top-k

        # Add context-based filters if vector store supports it.
        # NOTE(review): hasattr(vector_store, 'filter') is a heuristic
        # capability check — verify it holds for the stores you support.
        if hasattr(self.vector_store, 'filter'):
            filters = {}
            for key, value in context.items():
                # Only these whitelisted context keys become store filters.
                if key in ["company_key", "unique_id", "filter_id"]:
                    filters[key] = value
            if filters:
                kwargs["filter"] = filters

        return kwargs
|
|
387
|
+
|
ragfallback/py.typed
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Base class for fallback strategies."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import List, Dict, Any, Optional
|
|
5
|
+
|
|
6
|
+
from langchain_core.language_models import BaseLanguageModel
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class FallbackStrategy(ABC):
    """Abstract interface that every fallback strategy implements."""

    @abstractmethod
    def generate_queries(
        self,
        original_query: str,
        context: Dict[str, Any],
        attempt: int,
        llm: BaseLanguageModel
    ) -> List[str]:
        """Produce the query strings to try for one fallback attempt.

        Args:
            original_query: The query exactly as the caller phrased it
            context: Request metadata (e.g., {"company": "Acme"})
            attempt: 1-indexed attempt counter
            llm: Language model available for rewriting queries

        Returns:
            Candidate query strings, in the order they should be tried
        """
        ...

    def get_name(self) -> str:
        """Human-readable strategy identifier; defaults to the class name."""
        return type(self).__name__
|
|
37
|
+
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
"""Query Variations Strategy - LLM-based query rewriting."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
from typing import List, Dict, Any, Optional
|
|
6
|
+
|
|
7
|
+
from langchain_core.language_models import BaseLanguageModel
|
|
8
|
+
from langchain_core.messages import HumanMessage
|
|
9
|
+
|
|
10
|
+
from ragfallback.strategies.base import FallbackStrategy
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class QueryVariationsStrategy(FallbackStrategy):
    """
    Generate query variations using LLM.

    This strategy uses an LLM to generate alternative formulations
    of the original query, increasing the chances of finding relevant documents.
    """

    # Prompt sent to the LLM; {num_variations}, {query} and {context} are
    # filled in by _generate_variations(). The model is instructed to emit a
    # bare JSON array, which json.loads() below relies on.
    DEFAULT_PROMPT_TEMPLATE = """Generate {num_variations} alternative ways to ask this question that might find the answer in documentation.

Original question: "{query}"
{context}

Requirements:
- Use different terminology or phrasing
- Be more specific or more general as needed
- Focus on key concepts from the original question

Return ONLY a JSON array of strings: ["variation 1", "variation 2", ...]
Do not include any explanation or markdown formatting."""

    def __init__(
        self,
        num_variations: int = 2,
        include_original: bool = True,
        variation_prompt_template: Optional[str] = None
    ):
        """
        Initialize QueryVariationsStrategy.

        Args:
            num_variations: Number of variations to generate
            include_original: Include original query in results
                (only on the first attempt)
            variation_prompt_template: Custom prompt template; must accept
                the {num_variations}, {query} and {context} placeholders
        """
        self.num_variations = num_variations
        self.include_original = include_original
        self.prompt_template = variation_prompt_template or self.DEFAULT_PROMPT_TEMPLATE
        self.logger = logging.getLogger(__name__)

    def generate_queries(
        self,
        original_query: str,
        context: Dict[str, Any],
        attempt: int,
        llm: BaseLanguageModel
    ) -> List[str]:
        """
        Generate query variations.

        Args:
            original_query: Original query string
            context: Context dictionary
            attempt: Current attempt number (original query is only included
                when attempt == 1)
            llm: Language model for generation

        Returns:
            List of query strings (may be empty if generation fails)
        """
        queries = []

        # Include original query first, but only on the first attempt so
        # later attempts do not retry a query that already failed.
        if self.include_original and attempt == 1:
            queries.append(original_query)

        # Generate variations
        if self.num_variations > 0:
            variations = self._generate_variations(
                original_query=original_query,
                context=context,
                num_variations=self.num_variations,
                llm=llm
            )
            queries.extend(variations)

        self.logger.info(f"Generated {len(queries)} query variations")
        return queries

    def _generate_variations(
        self,
        original_query: str,
        context: Dict[str, Any],
        num_variations: int,
        llm: BaseLanguageModel
    ) -> List[str]:
        """Generate query variations using LLM.

        Best-effort: parse failures fall back to text extraction; any other
        error is logged and yields an empty list rather than raising.
        """
        # Build context string
        context_str = ""
        if context:
            context_str = f"\n\nContext: {json.dumps(context, indent=2)}"

        # Build prompt
        prompt = self.prompt_template.format(
            query=original_query,
            num_variations=num_variations,
            context=context_str
        )

        try:
            # Generate variations
            response = llm.invoke([HumanMessage(content=prompt)])
            response_text = response.content if hasattr(response, 'content') else str(response)

            # Parse JSON array; keep only string entries, since a malformed
            # model response may mix in numbers or nested objects.
            variations = json.loads(response_text)
            if isinstance(variations, list):
                return [v for v in variations if isinstance(v, str)][:num_variations]
            else:
                self.logger.warning(f"Expected list, got {type(variations)}")
                return []
        except json.JSONDecodeError as e:
            self.logger.error(f"Error parsing JSON from LLM response: {e}")
            # Try to extract array from markdown or text
            return self._extract_variations_from_text(response_text)
        except Exception as e:
            self.logger.error(f"Error generating query variations: {e}")
            return []

    def _extract_variations_from_text(self, text: str) -> List[str]:
        """Extract variations from text if JSON parsing fails.

        First retries JSON parsing on the innermost bracketed span (handles
        arrays wrapped in markdown or prose); otherwise falls back to
        splitting lines and keeping plausible-looking query strings.
        """
        import re
        # Try to find array-like patterns
        array_match = re.search(r'\[(.*?)\]', text, re.DOTALL)
        if array_match:
            try:
                variations = json.loads(f"[{array_match.group(1)}]")
                if isinstance(variations, list):
                    # Same validation and truncation as the primary path.
                    return [v for v in variations if isinstance(v, str)][:self.num_variations]
            # Was a bare `except:` — catch only the parse failure so real
            # errors (KeyboardInterrupt, etc.) are not silently swallowed.
            except json.JSONDecodeError:
                pass

        # Fallback: split by lines and clean; length > 10 filters out
        # brackets, bullets and other non-query fragments.
        lines = [line.strip().strip('"').strip("'") for line in text.split('\n')]
        return [line for line in lines if line and len(line) > 10][:self.num_variations]
|
|
146
|
+
|