kailash 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +33 -1
- kailash/access_control/__init__.py +129 -0
- kailash/access_control/managers.py +461 -0
- kailash/access_control/rule_evaluators.py +467 -0
- kailash/access_control_abac.py +825 -0
- kailash/config/__init__.py +27 -0
- kailash/config/database_config.py +359 -0
- kailash/database/__init__.py +28 -0
- kailash/database/execution_pipeline.py +499 -0
- kailash/middleware/__init__.py +306 -0
- kailash/middleware/auth/__init__.py +33 -0
- kailash/middleware/auth/access_control.py +436 -0
- kailash/middleware/auth/auth_manager.py +422 -0
- kailash/middleware/auth/jwt_auth.py +477 -0
- kailash/middleware/auth/kailash_jwt_auth.py +616 -0
- kailash/middleware/communication/__init__.py +37 -0
- kailash/middleware/communication/ai_chat.py +989 -0
- kailash/middleware/communication/api_gateway.py +802 -0
- kailash/middleware/communication/events.py +470 -0
- kailash/middleware/communication/realtime.py +710 -0
- kailash/middleware/core/__init__.py +21 -0
- kailash/middleware/core/agent_ui.py +890 -0
- kailash/middleware/core/schema.py +643 -0
- kailash/middleware/core/workflows.py +396 -0
- kailash/middleware/database/__init__.py +63 -0
- kailash/middleware/database/base.py +113 -0
- kailash/middleware/database/base_models.py +525 -0
- kailash/middleware/database/enums.py +106 -0
- kailash/middleware/database/migrations.py +12 -0
- kailash/{api/database.py → middleware/database/models.py} +183 -291
- kailash/middleware/database/repositories.py +685 -0
- kailash/middleware/database/session_manager.py +19 -0
- kailash/middleware/mcp/__init__.py +38 -0
- kailash/middleware/mcp/client_integration.py +585 -0
- kailash/middleware/mcp/enhanced_server.py +576 -0
- kailash/nodes/__init__.py +25 -3
- kailash/nodes/admin/__init__.py +35 -0
- kailash/nodes/admin/audit_log.py +794 -0
- kailash/nodes/admin/permission_check.py +864 -0
- kailash/nodes/admin/role_management.py +823 -0
- kailash/nodes/admin/security_event.py +1519 -0
- kailash/nodes/admin/user_management.py +944 -0
- kailash/nodes/ai/a2a.py +24 -7
- kailash/nodes/ai/ai_providers.py +1 -0
- kailash/nodes/ai/embedding_generator.py +11 -11
- kailash/nodes/ai/intelligent_agent_orchestrator.py +99 -11
- kailash/nodes/ai/llm_agent.py +407 -2
- kailash/nodes/ai/self_organizing.py +85 -10
- kailash/nodes/api/auth.py +287 -6
- kailash/nodes/api/rest.py +151 -0
- kailash/nodes/auth/__init__.py +17 -0
- kailash/nodes/auth/directory_integration.py +1228 -0
- kailash/nodes/auth/enterprise_auth_provider.py +1328 -0
- kailash/nodes/auth/mfa.py +2338 -0
- kailash/nodes/auth/risk_assessment.py +872 -0
- kailash/nodes/auth/session_management.py +1093 -0
- kailash/nodes/auth/sso.py +1040 -0
- kailash/nodes/base.py +344 -13
- kailash/nodes/base_cycle_aware.py +4 -2
- kailash/nodes/base_with_acl.py +1 -1
- kailash/nodes/code/python.py +293 -12
- kailash/nodes/compliance/__init__.py +9 -0
- kailash/nodes/compliance/data_retention.py +1888 -0
- kailash/nodes/compliance/gdpr.py +2004 -0
- kailash/nodes/data/__init__.py +22 -2
- kailash/nodes/data/async_connection.py +469 -0
- kailash/nodes/data/async_sql.py +757 -0
- kailash/nodes/data/async_vector.py +598 -0
- kailash/nodes/data/readers.py +767 -0
- kailash/nodes/data/retrieval.py +360 -1
- kailash/nodes/data/sharepoint_graph.py +397 -21
- kailash/nodes/data/sql.py +94 -5
- kailash/nodes/data/streaming.py +68 -8
- kailash/nodes/data/vector_db.py +54 -4
- kailash/nodes/enterprise/__init__.py +13 -0
- kailash/nodes/enterprise/batch_processor.py +741 -0
- kailash/nodes/enterprise/data_lineage.py +497 -0
- kailash/nodes/logic/convergence.py +31 -9
- kailash/nodes/logic/operations.py +14 -3
- kailash/nodes/mixins/__init__.py +8 -0
- kailash/nodes/mixins/event_emitter.py +201 -0
- kailash/nodes/mixins/mcp.py +9 -4
- kailash/nodes/mixins/security.py +165 -0
- kailash/nodes/monitoring/__init__.py +7 -0
- kailash/nodes/monitoring/performance_benchmark.py +2497 -0
- kailash/nodes/rag/__init__.py +284 -0
- kailash/nodes/rag/advanced.py +1615 -0
- kailash/nodes/rag/agentic.py +773 -0
- kailash/nodes/rag/conversational.py +999 -0
- kailash/nodes/rag/evaluation.py +875 -0
- kailash/nodes/rag/federated.py +1188 -0
- kailash/nodes/rag/graph.py +721 -0
- kailash/nodes/rag/multimodal.py +671 -0
- kailash/nodes/rag/optimized.py +933 -0
- kailash/nodes/rag/privacy.py +1059 -0
- kailash/nodes/rag/query_processing.py +1335 -0
- kailash/nodes/rag/realtime.py +764 -0
- kailash/nodes/rag/registry.py +547 -0
- kailash/nodes/rag/router.py +837 -0
- kailash/nodes/rag/similarity.py +1854 -0
- kailash/nodes/rag/strategies.py +566 -0
- kailash/nodes/rag/workflows.py +575 -0
- kailash/nodes/security/__init__.py +19 -0
- kailash/nodes/security/abac_evaluator.py +1411 -0
- kailash/nodes/security/audit_log.py +91 -0
- kailash/nodes/security/behavior_analysis.py +1893 -0
- kailash/nodes/security/credential_manager.py +401 -0
- kailash/nodes/security/rotating_credentials.py +760 -0
- kailash/nodes/security/security_event.py +132 -0
- kailash/nodes/security/threat_detection.py +1103 -0
- kailash/nodes/testing/__init__.py +9 -0
- kailash/nodes/testing/credential_testing.py +499 -0
- kailash/nodes/transform/__init__.py +10 -2
- kailash/nodes/transform/chunkers.py +592 -1
- kailash/nodes/transform/processors.py +484 -14
- kailash/nodes/validation.py +321 -0
- kailash/runtime/access_controlled.py +1 -1
- kailash/runtime/async_local.py +41 -7
- kailash/runtime/docker.py +1 -1
- kailash/runtime/local.py +474 -55
- kailash/runtime/parallel.py +1 -1
- kailash/runtime/parallel_cyclic.py +1 -1
- kailash/runtime/testing.py +210 -2
- kailash/utils/migrations/__init__.py +25 -0
- kailash/utils/migrations/generator.py +433 -0
- kailash/utils/migrations/models.py +231 -0
- kailash/utils/migrations/runner.py +489 -0
- kailash/utils/secure_logging.py +342 -0
- kailash/workflow/__init__.py +16 -0
- kailash/workflow/cyclic_runner.py +3 -4
- kailash/workflow/graph.py +70 -2
- kailash/workflow/resilience.py +249 -0
- kailash/workflow/templates.py +726 -0
- {kailash-0.3.1.dist-info → kailash-0.4.0.dist-info}/METADATA +253 -20
- kailash-0.4.0.dist-info/RECORD +223 -0
- kailash/api/__init__.py +0 -17
- kailash/api/__main__.py +0 -6
- kailash/api/studio_secure.py +0 -893
- kailash/mcp/__main__.py +0 -13
- kailash/mcp/server_new.py +0 -336
- kailash/mcp/servers/__init__.py +0 -12
- kailash-0.3.1.dist-info/RECORD +0 -136
- {kailash-0.3.1.dist-info → kailash-0.4.0.dist-info}/WHEEL +0 -0
- {kailash-0.3.1.dist-info → kailash-0.4.0.dist-info}/entry_points.txt +0 -0
- {kailash-0.3.1.dist-info → kailash-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.3.1.dist-info → kailash-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1615 @@
|
|
1
|
+
"""
|
2
|
+
Advanced RAG Techniques
|
3
|
+
|
4
|
+
Implementation of cutting-edge RAG patterns including:
|
5
|
+
- Self-Correcting RAG with verification
|
6
|
+
- RAG-Fusion with multi-query approach
|
7
|
+
- HyDE (Hypothetical Document Embeddings)
|
8
|
+
- Step-Back prompting for abstract reasoning
|
9
|
+
- Advanced query processing and enhancement
|
10
|
+
|
11
|
+
All techniques use existing Kailash components and WorkflowBuilder patterns.
|
12
|
+
"""
|
13
|
+
|
14
|
+
import asyncio
|
15
|
+
import json
|
16
|
+
import logging
|
17
|
+
from typing import Any, Dict, List, Optional, Union
|
18
|
+
|
19
|
+
from ...workflow.builder import WorkflowBuilder
|
20
|
+
from ..ai.llm_agent import LLMAgentNode
|
21
|
+
from ..base import Node, NodeParameter, register_node
|
22
|
+
from ..logic.workflow import WorkflowNode
|
23
|
+
|
24
|
+
logger = logging.getLogger(__name__)
|
25
|
+
|
26
|
+
|
27
|
+
# Simple RAGConfig fallback to avoid circular import
|
28
|
+
class RAGConfig:
|
29
|
+
"""Simple RAG configuration"""
|
30
|
+
|
31
|
+
def __init__(self, **kwargs):
|
32
|
+
self.chunk_size = kwargs.get("chunk_size", 1000)
|
33
|
+
self.chunk_overlap = kwargs.get("chunk_overlap", 200)
|
34
|
+
self.embedding_model = kwargs.get("embedding_model", "text-embedding-3-small")
|
35
|
+
self.retrieval_k = kwargs.get("retrieval_k", 5)
|
36
|
+
|
37
|
+
|
38
|
+
def create_hybrid_rag_workflow(config):
|
39
|
+
"""Simple fallback workflow creator"""
|
40
|
+
# In a real implementation, this would create a proper workflow
|
41
|
+
# For now, return a simple mock workflow
|
42
|
+
from ...workflow.graph import Workflow
|
43
|
+
|
44
|
+
return Workflow(name="hybrid_rag_fallback", nodes=[], connections=[])
|
45
|
+
|
46
|
+
|
47
|
+
@register_node()
|
48
|
+
class SelfCorrectingRAGNode(Node):
|
49
|
+
"""
|
50
|
+
Self-Correcting RAG with Verification
|
51
|
+
|
52
|
+
Implements self-verification and iterative correction mechanisms.
|
53
|
+
Uses LLM to assess retrieval quality and refine results automatically.
|
54
|
+
|
55
|
+
Based on 2024 research: Corrective RAG (CRAG) and Self-RAG patterns.
|
56
|
+
"""
|
57
|
+
|
58
|
+
def __init__(
|
59
|
+
self,
|
60
|
+
name: str = "self_correcting_rag",
|
61
|
+
max_corrections: int = 2,
|
62
|
+
confidence_threshold: float = 0.8,
|
63
|
+
verification_model: str = "gpt-4",
|
64
|
+
):
|
65
|
+
self.max_corrections = max_corrections
|
66
|
+
self.confidence_threshold = confidence_threshold
|
67
|
+
self.verification_model = verification_model
|
68
|
+
self.base_rag_workflow = None
|
69
|
+
self.verifier_agent = None
|
70
|
+
super().__init__(name)
|
71
|
+
|
72
|
+
def get_parameters(self) -> Dict[str, NodeParameter]:
|
73
|
+
return {
|
74
|
+
"documents": NodeParameter(
|
75
|
+
name="documents",
|
76
|
+
type=list,
|
77
|
+
required=True,
|
78
|
+
description="Documents for RAG processing",
|
79
|
+
),
|
80
|
+
"query": NodeParameter(
|
81
|
+
name="query",
|
82
|
+
type=str,
|
83
|
+
required=True,
|
84
|
+
description="Query for retrieval and generation",
|
85
|
+
),
|
86
|
+
"config": NodeParameter(
|
87
|
+
name="config",
|
88
|
+
type=dict,
|
89
|
+
required=False,
|
90
|
+
description="RAG configuration parameters",
|
91
|
+
),
|
92
|
+
}
|
93
|
+
|
94
|
+
def run(self, **kwargs) -> Dict[str, Any]:
|
95
|
+
"""Execute self-correcting RAG with iterative refinement"""
|
96
|
+
documents = kwargs.get("documents", [])
|
97
|
+
query = kwargs.get("query", "")
|
98
|
+
config = kwargs.get("config", {})
|
99
|
+
|
100
|
+
# Initialize components
|
101
|
+
self._initialize_components(config)
|
102
|
+
|
103
|
+
# Track correction attempts
|
104
|
+
correction_history = []
|
105
|
+
|
106
|
+
for attempt in range(self.max_corrections + 1):
|
107
|
+
logger.info(f"Self-correcting RAG attempt {attempt + 1}")
|
108
|
+
|
109
|
+
# Perform RAG retrieval and generation
|
110
|
+
rag_result = self._perform_rag(documents, query, attempt)
|
111
|
+
|
112
|
+
# Verify result quality
|
113
|
+
verification = self._verify_result_quality(query, rag_result, documents)
|
114
|
+
|
115
|
+
correction_history.append(
|
116
|
+
{
|
117
|
+
"attempt": attempt + 1,
|
118
|
+
"rag_result": rag_result,
|
119
|
+
"verification": verification,
|
120
|
+
"confidence": verification.get("confidence", 0.0),
|
121
|
+
}
|
122
|
+
)
|
123
|
+
|
124
|
+
# Check if result is satisfactory
|
125
|
+
if verification.get("confidence", 0.0) >= self.confidence_threshold:
|
126
|
+
logger.info(f"Self-correction successful at attempt {attempt + 1}")
|
127
|
+
return self._format_final_result(
|
128
|
+
rag_result, verification, correction_history
|
129
|
+
)
|
130
|
+
|
131
|
+
# If not final attempt, prepare for correction
|
132
|
+
if attempt < self.max_corrections:
|
133
|
+
documents = self._refine_documents(query, documents, verification)
|
134
|
+
query = self._refine_query(query, verification)
|
135
|
+
|
136
|
+
# Return best attempt if all corrections exhausted
|
137
|
+
best_attempt = max(correction_history, key=lambda x: x["confidence"])
|
138
|
+
logger.warning(
|
139
|
+
f"Self-correction completed with best confidence: {best_attempt['confidence']:.3f}"
|
140
|
+
)
|
141
|
+
|
142
|
+
return self._format_final_result(
|
143
|
+
best_attempt["rag_result"], best_attempt["verification"], correction_history
|
144
|
+
)
|
145
|
+
|
146
|
+
def _initialize_components(self, config: Dict[str, Any]):
|
147
|
+
"""Initialize RAG workflow and verification components"""
|
148
|
+
if not self.base_rag_workflow:
|
149
|
+
rag_config = RAGConfig(**config) if config else RAGConfig()
|
150
|
+
self.base_rag_workflow = create_hybrid_rag_workflow(rag_config)
|
151
|
+
|
152
|
+
if not self.verifier_agent:
|
153
|
+
self.verifier_agent = LLMAgentNode(
|
154
|
+
name=f"{self.name}_verifier",
|
155
|
+
model=self.verification_model,
|
156
|
+
provider="openai",
|
157
|
+
system_prompt=self._get_verification_prompt(),
|
158
|
+
)
|
159
|
+
|
160
|
+
def _get_verification_prompt(self) -> str:
|
161
|
+
"""Get system prompt for result verification"""
|
162
|
+
return """You are a RAG quality assessment expert. Your job is to evaluate retrieval and generation quality.
|
163
|
+
|
164
|
+
Analyze the query, retrieved documents, and generated response to assess:
|
165
|
+
|
166
|
+
1. **Retrieval Quality** (0.0-1.0):
|
167
|
+
- Relevance: How well do retrieved docs match the query?
|
168
|
+
- Coverage: Do docs contain information needed to answer?
|
169
|
+
- Diversity: Are different aspects of the query covered?
|
170
|
+
|
171
|
+
2. **Generation Quality** (0.0-1.0):
|
172
|
+
- Faithfulness: Is response consistent with retrieved docs?
|
173
|
+
- Completeness: Does response fully address the query?
|
174
|
+
- Clarity: Is response clear and well-structured?
|
175
|
+
|
176
|
+
3. **Overall Confidence** (0.0-1.0):
|
177
|
+
- Combined assessment of retrieval and generation
|
178
|
+
- Higher confidence = better quality
|
179
|
+
|
180
|
+
4. **Improvement Suggestions**:
|
181
|
+
- Specific actionable recommendations
|
182
|
+
- Query refinements if needed
|
183
|
+
- Document filtering suggestions
|
184
|
+
|
185
|
+
Respond with JSON only:
|
186
|
+
{
|
187
|
+
"retrieval_quality": 0.0-1.0,
|
188
|
+
"generation_quality": 0.0-1.0,
|
189
|
+
"confidence": 0.0-1.0,
|
190
|
+
"issues": ["list of specific issues found"],
|
191
|
+
"suggestions": ["list of improvement recommendations"],
|
192
|
+
"needs_refinement": true/false,
|
193
|
+
"reasoning": "brief explanation of assessment"
|
194
|
+
}"""
|
195
|
+
|
196
|
+
def _perform_rag(
|
197
|
+
self, documents: List[Dict], query: str, attempt: int
|
198
|
+
) -> Dict[str, Any]:
|
199
|
+
"""Perform RAG retrieval and generation"""
|
200
|
+
try:
|
201
|
+
# Add attempt context for potential query modification
|
202
|
+
if attempt > 0:
|
203
|
+
query_with_context = f"[Refinement attempt {attempt}] {query}"
|
204
|
+
else:
|
205
|
+
query_with_context = query
|
206
|
+
|
207
|
+
# Execute base RAG workflow
|
208
|
+
result = self.base_rag_workflow.run(
|
209
|
+
documents=documents, query=query_with_context, operation="retrieve"
|
210
|
+
)
|
211
|
+
|
212
|
+
return {
|
213
|
+
"query": query,
|
214
|
+
"retrieved_documents": result.get("results", []),
|
215
|
+
"scores": result.get("scores", []),
|
216
|
+
"generated_response": self._generate_response(
|
217
|
+
query, result.get("results", [])
|
218
|
+
),
|
219
|
+
"metadata": result.get("metadata", {}),
|
220
|
+
"attempt": attempt + 1,
|
221
|
+
}
|
222
|
+
|
223
|
+
except Exception as e:
|
224
|
+
logger.error(f"RAG execution failed at attempt {attempt + 1}: {e}")
|
225
|
+
return {
|
226
|
+
"query": query,
|
227
|
+
"retrieved_documents": [],
|
228
|
+
"scores": [],
|
229
|
+
"generated_response": f"Error during RAG processing: {str(e)}",
|
230
|
+
"error": str(e),
|
231
|
+
"attempt": attempt + 1,
|
232
|
+
}
|
233
|
+
|
234
|
+
def _generate_response(self, query: str, retrieved_docs: List[Dict]) -> str:
|
235
|
+
"""Generate response from retrieved documents"""
|
236
|
+
if not retrieved_docs:
|
237
|
+
return "No relevant documents found to answer the query."
|
238
|
+
|
239
|
+
# Simple response generation (can be enhanced with dedicated LLM)
|
240
|
+
context = "\n\n".join(
|
241
|
+
[
|
242
|
+
f"Document {i+1}: {doc.get('content', '')[:500]}..."
|
243
|
+
for i, doc in enumerate(retrieved_docs[:3])
|
244
|
+
]
|
245
|
+
)
|
246
|
+
|
247
|
+
return f"Based on the retrieved documents, here is the response to '{query}':\n\n{context}"
|
248
|
+
|
249
|
+
def _verify_result_quality(
|
250
|
+
self, query: str, rag_result: Dict, original_docs: List[Dict]
|
251
|
+
) -> Dict[str, Any]:
|
252
|
+
"""Verify quality of RAG result using LLM"""
|
253
|
+
verification_input = self._format_verification_input(
|
254
|
+
query, rag_result, original_docs
|
255
|
+
)
|
256
|
+
|
257
|
+
try:
|
258
|
+
verification_response = self.verifier_agent.run(
|
259
|
+
messages=[{"role": "user", "content": verification_input}]
|
260
|
+
)
|
261
|
+
|
262
|
+
# Parse LLM response
|
263
|
+
verification = self._parse_verification_response(verification_response)
|
264
|
+
return verification
|
265
|
+
|
266
|
+
except Exception as e:
|
267
|
+
logger.error(f"Verification failed: {e}")
|
268
|
+
return {
|
269
|
+
"retrieval_quality": 0.5,
|
270
|
+
"generation_quality": 0.5,
|
271
|
+
"confidence": 0.5,
|
272
|
+
"issues": [f"Verification error: {str(e)}"],
|
273
|
+
"suggestions": ["Manual review recommended"],
|
274
|
+
"needs_refinement": True,
|
275
|
+
"reasoning": "Automated verification failed",
|
276
|
+
}
|
277
|
+
|
278
|
+
def _format_verification_input(
|
279
|
+
self, query: str, rag_result: Dict, original_docs: List[Dict]
|
280
|
+
) -> str:
|
281
|
+
"""Format input for verification LLM"""
|
282
|
+
retrieved_docs = rag_result.get("retrieved_documents", [])
|
283
|
+
response = rag_result.get("generated_response", "")
|
284
|
+
|
285
|
+
return f"""
|
286
|
+
QUERY: {query}
|
287
|
+
|
288
|
+
RETRIEVED DOCUMENTS ({len(retrieved_docs)} of {len(original_docs)} total):
|
289
|
+
{self._format_documents_for_verification(retrieved_docs)}
|
290
|
+
|
291
|
+
GENERATED RESPONSE:
|
292
|
+
{response}
|
293
|
+
|
294
|
+
RETRIEVAL SCORES: {rag_result.get("scores", [])}
|
295
|
+
|
296
|
+
Assess the quality and provide improvement suggestions:
|
297
|
+
"""
|
298
|
+
|
299
|
+
def _format_documents_for_verification(self, docs: List[Dict]) -> str:
|
300
|
+
"""Format documents for verification prompt"""
|
301
|
+
formatted = []
|
302
|
+
for i, doc in enumerate(docs[:5]): # Limit to 5 docs for prompt length
|
303
|
+
content = doc.get("content", "")[:300] # Truncate for prompt
|
304
|
+
formatted.append(f"Doc {i+1}: {content}...")
|
305
|
+
return "\n\n".join(formatted)
|
306
|
+
|
307
|
+
def _parse_verification_response(self, response: Dict) -> Dict[str, Any]:
|
308
|
+
"""Parse verification response from LLM"""
|
309
|
+
try:
|
310
|
+
content = response.get("content", "")
|
311
|
+
if isinstance(content, list):
|
312
|
+
content = content[0] if content else "{}"
|
313
|
+
|
314
|
+
# Extract JSON from response
|
315
|
+
if "{" in content and "}" in content:
|
316
|
+
json_start = content.find("{")
|
317
|
+
json_end = content.rfind("}") + 1
|
318
|
+
json_str = content[json_start:json_end]
|
319
|
+
verification = json.loads(json_str)
|
320
|
+
|
321
|
+
# Validate required fields
|
322
|
+
required_fields = [
|
323
|
+
"confidence",
|
324
|
+
"retrieval_quality",
|
325
|
+
"generation_quality",
|
326
|
+
]
|
327
|
+
if all(field in verification for field in required_fields):
|
328
|
+
return verification
|
329
|
+
|
330
|
+
# Fallback if parsing fails
|
331
|
+
return self._create_fallback_verification(content)
|
332
|
+
|
333
|
+
except Exception as e:
|
334
|
+
logger.warning(f"Failed to parse verification response: {e}")
|
335
|
+
return self._create_fallback_verification(str(e))
|
336
|
+
|
337
|
+
def _create_fallback_verification(self, content: str) -> Dict[str, Any]:
|
338
|
+
"""Create fallback verification when parsing fails"""
|
339
|
+
# Simple heuristic based on content
|
340
|
+
confidence = (
|
341
|
+
0.6 if "good" in content.lower() or "relevant" in content.lower() else 0.4
|
342
|
+
)
|
343
|
+
|
344
|
+
return {
|
345
|
+
"retrieval_quality": confidence,
|
346
|
+
"generation_quality": confidence,
|
347
|
+
"confidence": confidence,
|
348
|
+
"issues": ["Automated verification parsing failed"],
|
349
|
+
"suggestions": ["Manual review recommended"],
|
350
|
+
"needs_refinement": confidence < self.confidence_threshold,
|
351
|
+
"reasoning": "Fallback assessment due to parsing error",
|
352
|
+
}
|
353
|
+
|
354
|
+
def _refine_documents(
|
355
|
+
self, query: str, documents: List[Dict], verification: Dict
|
356
|
+
) -> List[Dict]:
|
357
|
+
"""Refine document set based on verification feedback"""
|
358
|
+
issues = verification.get("issues", [])
|
359
|
+
suggestions = verification.get("suggestions", [])
|
360
|
+
|
361
|
+
# Simple refinement: filter documents if suggested
|
362
|
+
if any("filter" in suggestion.lower() for suggestion in suggestions):
|
363
|
+
# Keep top 80% of documents by relevance
|
364
|
+
keep_count = max(1, int(len(documents) * 0.8))
|
365
|
+
return documents[:keep_count]
|
366
|
+
|
367
|
+
# If no specific refinement suggested, return original
|
368
|
+
return documents
|
369
|
+
|
370
|
+
def _refine_query(self, query: str, verification: Dict) -> str:
|
371
|
+
"""Refine query based on verification feedback"""
|
372
|
+
suggestions = verification.get("suggestions", [])
|
373
|
+
|
374
|
+
# Simple query refinement based on suggestions
|
375
|
+
for suggestion in suggestions:
|
376
|
+
if "more specific" in suggestion.lower():
|
377
|
+
return f"{query} (please provide specific details)"
|
378
|
+
elif "broader" in suggestion.lower():
|
379
|
+
return f"What are the key aspects of {query}?"
|
380
|
+
|
381
|
+
return query # Return original if no refinement suggested
|
382
|
+
|
383
|
+
def _format_final_result(
|
384
|
+
self, rag_result: Dict, verification: Dict, history: List[Dict]
|
385
|
+
) -> Dict[str, Any]:
|
386
|
+
"""Format final self-correcting RAG result"""
|
387
|
+
return {
|
388
|
+
"query": rag_result.get("query"),
|
389
|
+
"final_response": rag_result.get("generated_response"),
|
390
|
+
"retrieved_documents": rag_result.get("retrieved_documents", []),
|
391
|
+
"scores": rag_result.get("scores", []),
|
392
|
+
"quality_assessment": {
|
393
|
+
"confidence": verification.get("confidence"),
|
394
|
+
"retrieval_quality": verification.get("retrieval_quality"),
|
395
|
+
"generation_quality": verification.get("generation_quality"),
|
396
|
+
"issues_found": verification.get("issues", []),
|
397
|
+
"improvements_made": verification.get("suggestions", []),
|
398
|
+
},
|
399
|
+
"self_correction_metadata": {
|
400
|
+
"total_attempts": len(history),
|
401
|
+
"final_attempt": history[-1]["attempt"] if history else 1,
|
402
|
+
"correction_history": history,
|
403
|
+
"threshold_met": verification.get("confidence", 0.0)
|
404
|
+
>= self.confidence_threshold,
|
405
|
+
},
|
406
|
+
"status": (
|
407
|
+
"corrected"
|
408
|
+
if verification.get("confidence", 0.0) >= self.confidence_threshold
|
409
|
+
else "best_effort"
|
410
|
+
),
|
411
|
+
}
|
412
|
+
|
413
|
+
|
414
|
+
@register_node()
|
415
|
+
class RAGFusionNode(Node):
|
416
|
+
"""
|
417
|
+
RAG-Fusion with Multi-Query Approach
|
418
|
+
|
419
|
+
Generates multiple query variations and fuses results using
|
420
|
+
Reciprocal Rank Fusion (RRF) for improved retrieval performance.
|
421
|
+
|
422
|
+
Provides 15-20% improvement in recall and robustness to query phrasing.
|
423
|
+
|
424
|
+
When to use:
|
425
|
+
- Best for: Ambiguous queries, exploratory search, comprehensive coverage
|
426
|
+
- Not ideal for: Precise technical lookups, when exact matching needed
|
427
|
+
- Performance: ~1 second per query variation
|
428
|
+
- Recall improvement: 20-35% over single query
|
429
|
+
|
430
|
+
Key features:
|
431
|
+
- Automatic query variation generation
|
432
|
+
- Parallel retrieval execution
|
433
|
+
- Reciprocal Rank Fusion
|
434
|
+
- Diversity-aware result selection
|
435
|
+
|
436
|
+
Example:
|
437
|
+
rag_fusion = RAGFusionNode(
|
438
|
+
num_query_variations=5,
|
439
|
+
fusion_method="rrf"
|
440
|
+
)
|
441
|
+
|
442
|
+
# Query: "How to optimize neural networks"
|
443
|
+
# Generates variations:
|
444
|
+
# - "neural network optimization techniques"
|
445
|
+
# - "methods for improving neural network performance"
|
446
|
+
# - "deep learning model optimization strategies"
|
447
|
+
# - "tuning neural network hyperparameters"
|
448
|
+
# - "neural network training optimization"
|
449
|
+
|
450
|
+
result = await rag_fusion.run(
|
451
|
+
documents=documents,
|
452
|
+
query="How to optimize neural networks"
|
453
|
+
)
|
454
|
+
|
455
|
+
Parameters:
|
456
|
+
num_query_variations: Number of query alternatives
|
457
|
+
fusion_method: Result combination strategy (rrf, weighted)
|
458
|
+
query_generator_model: LLM for variation generation
|
459
|
+
diversity_weight: Emphasis on result diversity
|
460
|
+
|
461
|
+
Returns:
|
462
|
+
results: Fused results from all queries
|
463
|
+
query_variations: Generated query alternatives
|
464
|
+
fusion_metadata: Per-query contributions and statistics
|
465
|
+
diversity_score: Result set diversity metric
|
466
|
+
"""
|
467
|
+
|
468
|
+
def __init__(
|
469
|
+
self,
|
470
|
+
name: str = "rag_fusion",
|
471
|
+
num_query_variations: int = 3,
|
472
|
+
fusion_method: str = "rrf",
|
473
|
+
query_generator_model: str = "gpt-4",
|
474
|
+
):
|
475
|
+
self.num_query_variations = num_query_variations
|
476
|
+
self.fusion_method = fusion_method
|
477
|
+
self.query_generator_model = query_generator_model
|
478
|
+
self.query_generator = None
|
479
|
+
self.base_rag_workflow = None
|
480
|
+
super().__init__(name)
|
481
|
+
|
482
|
+
def get_parameters(self) -> Dict[str, NodeParameter]:
|
483
|
+
return {
|
484
|
+
"documents": NodeParameter(
|
485
|
+
name="documents",
|
486
|
+
type=list,
|
487
|
+
required=True,
|
488
|
+
description="Documents for RAG processing",
|
489
|
+
),
|
490
|
+
"query": NodeParameter(
|
491
|
+
name="query",
|
492
|
+
type=str,
|
493
|
+
required=True,
|
494
|
+
description="Original query for fusion processing",
|
495
|
+
),
|
496
|
+
"config": NodeParameter(
|
497
|
+
name="config",
|
498
|
+
type=dict,
|
499
|
+
required=False,
|
500
|
+
description="RAG configuration parameters",
|
501
|
+
),
|
502
|
+
}
|
503
|
+
|
504
|
+
def run(self, **kwargs) -> Dict[str, Any]:
|
505
|
+
"""Execute RAG-Fusion with multi-query approach"""
|
506
|
+
documents = kwargs.get("documents", [])
|
507
|
+
original_query = kwargs.get("query", "")
|
508
|
+
config = kwargs.get("config", {})
|
509
|
+
|
510
|
+
# Initialize components
|
511
|
+
self._initialize_components(config)
|
512
|
+
|
513
|
+
# Generate query variations
|
514
|
+
query_variations = self._generate_query_variations(original_query)
|
515
|
+
all_queries = [original_query] + query_variations
|
516
|
+
|
517
|
+
logger.info(f"RAG-Fusion processing {len(all_queries)} queries")
|
518
|
+
|
519
|
+
# Retrieve for each query
|
520
|
+
all_results = []
|
521
|
+
query_performances = []
|
522
|
+
|
523
|
+
for i, query in enumerate(all_queries):
|
524
|
+
try:
|
525
|
+
result = self._retrieve_for_query(query, documents)
|
526
|
+
all_results.append(result)
|
527
|
+
|
528
|
+
query_performances.append(
|
529
|
+
{
|
530
|
+
"query": query,
|
531
|
+
"is_original": i == 0,
|
532
|
+
"results_count": len(result.get("results", [])),
|
533
|
+
"avg_score": (
|
534
|
+
sum(result.get("scores", []))
|
535
|
+
/ len(result.get("scores", []))
|
536
|
+
if result.get("scores")
|
537
|
+
else 0.0
|
538
|
+
),
|
539
|
+
}
|
540
|
+
)
|
541
|
+
|
542
|
+
except Exception as e:
|
543
|
+
logger.error(f"Query retrieval failed for '{query}': {e}")
|
544
|
+
query_performances.append(
|
545
|
+
{
|
546
|
+
"query": query,
|
547
|
+
"is_original": i == 0,
|
548
|
+
"error": str(e),
|
549
|
+
"results_count": 0,
|
550
|
+
"avg_score": 0.0,
|
551
|
+
}
|
552
|
+
)
|
553
|
+
|
554
|
+
# Fuse results using specified method
|
555
|
+
fused_results = self._fuse_results(all_results, method=self.fusion_method)
|
556
|
+
|
557
|
+
# Generate final response
|
558
|
+
final_response = self._generate_fused_response(original_query, fused_results)
|
559
|
+
|
560
|
+
return {
|
561
|
+
"original_query": original_query,
|
562
|
+
"query_variations": query_variations,
|
563
|
+
"fused_results": fused_results,
|
564
|
+
"final_response": final_response,
|
565
|
+
"fusion_metadata": {
|
566
|
+
"fusion_method": self.fusion_method,
|
567
|
+
"queries_processed": len(all_queries),
|
568
|
+
"query_performances": query_performances,
|
569
|
+
"total_unique_documents": len(
|
570
|
+
set(
|
571
|
+
doc.get("id", doc.get("content", "")[:50])
|
572
|
+
for doc in fused_results.get("documents", [])
|
573
|
+
)
|
574
|
+
),
|
575
|
+
"fusion_score_improvement": self._calculate_fusion_improvement(
|
576
|
+
all_results, fused_results
|
577
|
+
),
|
578
|
+
},
|
579
|
+
}
|
580
|
+
|
581
|
+
def _initialize_components(self, config: Dict[str, Any]):
|
582
|
+
"""Initialize query generator and base RAG workflow"""
|
583
|
+
if not self.query_generator:
|
584
|
+
self.query_generator = LLMAgentNode(
|
585
|
+
name=f"{self.name}_query_generator",
|
586
|
+
model=self.query_generator_model,
|
587
|
+
provider="openai",
|
588
|
+
system_prompt=self._get_query_generation_prompt(),
|
589
|
+
)
|
590
|
+
|
591
|
+
if not self.base_rag_workflow:
|
592
|
+
rag_config = RAGConfig(**config) if config else RAGConfig()
|
593
|
+
self.base_rag_workflow = create_hybrid_rag_workflow(rag_config)
|
594
|
+
|
595
|
+
def _get_query_generation_prompt(self) -> str:
|
596
|
+
"""Get system prompt for query variation generation"""
|
597
|
+
return f"""You are an expert query expansion specialist. Your job is to generate {self.num_query_variations} diverse, high-quality variations of a user query for improved document retrieval.
|
598
|
+
|
599
|
+
Guidelines:
|
600
|
+
1. **Maintain Intent**: All variations must preserve the original query's intent and information need
|
601
|
+
2. **Increase Diversity**: Use different phrasings, terminology, and approaches
|
602
|
+
3. **Enhance Coverage**: Cover different aspects or angles of the query
|
603
|
+
4. **Improve Specificity**: Some variations should be more specific, others more general
|
604
|
+
|
605
|
+
Variation Types to Consider:
|
606
|
+
- **Rephrasing**: Different words with same meaning
|
607
|
+
- **Perspective Shift**: Different viewpoints on the same topic
|
608
|
+
- **Granularity Change**: More specific or more general versions
|
609
|
+
- **Domain Terms**: Use technical vs. common terminology
|
610
|
+
- **Question Types**: Convert statements to questions or vice versa
|
611
|
+
|
612
|
+
Respond with JSON only:
|
613
|
+
{{
|
614
|
+
"variations": [
|
615
|
+
"variation 1",
|
616
|
+
"variation 2",
|
617
|
+
"variation 3"
|
618
|
+
],
|
619
|
+
"reasoning": "brief explanation of variation strategy"
|
620
|
+
}}"""
|
621
|
+
|
622
|
+
def _generate_query_variations(self, original_query: str) -> List[str]:
|
623
|
+
"""Generate query variations using LLM"""
|
624
|
+
try:
|
625
|
+
generation_input = f"""
|
626
|
+
Original Query: {original_query}
|
627
|
+
|
628
|
+
Generate {self.num_query_variations} high-quality variations that will improve retrieval coverage:
|
629
|
+
"""
|
630
|
+
|
631
|
+
response = self.query_generator.run(
|
632
|
+
messages=[{"role": "user", "content": generation_input}]
|
633
|
+
)
|
634
|
+
|
635
|
+
# Parse response
|
636
|
+
variations = self._parse_query_variations(response)
|
637
|
+
logger.info(f"Generated {len(variations)} query variations")
|
638
|
+
return variations
|
639
|
+
|
640
|
+
except Exception as e:
|
641
|
+
logger.error(f"Query variation generation failed: {e}")
|
642
|
+
# Fallback to simple variations
|
643
|
+
return self._generate_fallback_variations(original_query)
|
644
|
+
|
645
|
+
def _parse_query_variations(self, response: Dict) -> List[str]:
|
646
|
+
"""Parse query variations from LLM response"""
|
647
|
+
try:
|
648
|
+
content = response.get("content", "")
|
649
|
+
if isinstance(content, list):
|
650
|
+
content = content[0] if content else "{}"
|
651
|
+
|
652
|
+
# Extract JSON
|
653
|
+
if "{" in content and "}" in content:
|
654
|
+
json_start = content.find("{")
|
655
|
+
json_end = content.rfind("}") + 1
|
656
|
+
json_str = content[json_start:json_end]
|
657
|
+
parsed = json.loads(json_str)
|
658
|
+
|
659
|
+
variations = parsed.get("variations", [])
|
660
|
+
if variations and isinstance(variations, list):
|
661
|
+
return variations[: self.num_query_variations]
|
662
|
+
|
663
|
+
# Fallback parsing
|
664
|
+
return self._extract_variations_from_text(content)
|
665
|
+
|
666
|
+
except Exception as e:
|
667
|
+
logger.warning(f"Failed to parse query variations: {e}")
|
668
|
+
return []
|
669
|
+
|
670
|
+
def _extract_variations_from_text(self, content: str) -> List[str]:
|
671
|
+
"""Extract variations from text when JSON parsing fails"""
|
672
|
+
variations = []
|
673
|
+
lines = content.split("\n")
|
674
|
+
|
675
|
+
for line in lines:
|
676
|
+
line = line.strip()
|
677
|
+
# Look for numbered or bulleted lists
|
678
|
+
if any(
|
679
|
+
line.startswith(prefix) for prefix in ["1.", "2.", "3.", "-", "*", "•"]
|
680
|
+
):
|
681
|
+
# Clean up the line
|
682
|
+
for prefix in ["1.", "2.", "3.", "-", "*", "•", '"', "'"]:
|
683
|
+
line = line.lstrip(prefix).strip()
|
684
|
+
if line and len(line) > 10: # Basic quality filter
|
685
|
+
variations.append(line)
|
686
|
+
|
687
|
+
return variations[: self.num_query_variations]
|
688
|
+
|
689
|
+
def _generate_fallback_variations(self, original_query: str) -> List[str]:
|
690
|
+
"""Generate simple variations when LLM generation fails"""
|
691
|
+
variations = []
|
692
|
+
|
693
|
+
# Simple transformation patterns
|
694
|
+
if "?" not in original_query:
|
695
|
+
variations.append(f"What is {original_query}?")
|
696
|
+
|
697
|
+
if "how" not in original_query.lower():
|
698
|
+
variations.append(f"How does {original_query} work?")
|
699
|
+
|
700
|
+
if len(original_query.split()) > 3:
|
701
|
+
# Extract key terms
|
702
|
+
words = original_query.split()
|
703
|
+
key_terms = words[:3] # First 3 words
|
704
|
+
variations.append(f"Explain {' '.join(key_terms)}")
|
705
|
+
|
706
|
+
return variations[: self.num_query_variations]
|
707
|
+
|
708
|
+
def _retrieve_for_query(self, query: str, documents: List[Dict]) -> Dict[str, Any]:
|
709
|
+
"""Retrieve documents for a single query"""
|
710
|
+
return self.base_rag_workflow.run(
|
711
|
+
documents=documents, query=query, operation="retrieve"
|
712
|
+
)
|
713
|
+
|
714
|
+
def _fuse_results(
|
715
|
+
self, all_results: List[Dict], method: str = "rrf"
|
716
|
+
) -> Dict[str, Any]:
|
717
|
+
"""Fuse results from multiple queries"""
|
718
|
+
if method == "rrf":
|
719
|
+
return self._reciprocal_rank_fusion(all_results)
|
720
|
+
elif method == "weighted":
|
721
|
+
return self._weighted_fusion(all_results)
|
722
|
+
elif method == "simple":
|
723
|
+
return self._simple_concatenation(all_results)
|
724
|
+
else:
|
725
|
+
logger.warning(f"Unknown fusion method: {method}, using RRF")
|
726
|
+
return self._reciprocal_rank_fusion(all_results)
|
727
|
+
|
728
|
+
def _reciprocal_rank_fusion(
|
729
|
+
self, all_results: List[Dict], k: int = 60
|
730
|
+
) -> Dict[str, Any]:
|
731
|
+
"""Implement Reciprocal Rank Fusion (RRF)"""
|
732
|
+
doc_scores = {}
|
733
|
+
doc_contents = {}
|
734
|
+
|
735
|
+
for query_idx, result in enumerate(all_results):
|
736
|
+
documents = result.get("results", [])
|
737
|
+
|
738
|
+
for rank, doc in enumerate(documents):
|
739
|
+
doc_id = doc.get("id", doc.get("content", "")[:50]) # Fallback ID
|
740
|
+
|
741
|
+
# RRF score calculation
|
742
|
+
rrf_score = 1 / (k + rank + 1)
|
743
|
+
|
744
|
+
if doc_id not in doc_scores:
|
745
|
+
doc_scores[doc_id] = {
|
746
|
+
"score": 0.0,
|
747
|
+
"query_sources": [],
|
748
|
+
"original_ranks": [],
|
749
|
+
}
|
750
|
+
doc_contents[doc_id] = doc
|
751
|
+
|
752
|
+
doc_scores[doc_id]["score"] += rrf_score
|
753
|
+
doc_scores[doc_id]["query_sources"].append(query_idx)
|
754
|
+
doc_scores[doc_id]["original_ranks"].append(rank + 1)
|
755
|
+
|
756
|
+
# Sort by fused score
|
757
|
+
sorted_docs = sorted(
|
758
|
+
doc_scores.items(), key=lambda x: x[1]["score"], reverse=True
|
759
|
+
)
|
760
|
+
|
761
|
+
# Format result
|
762
|
+
fused_documents = []
|
763
|
+
fused_scores = []
|
764
|
+
|
765
|
+
for doc_id, score_info in sorted_docs:
|
766
|
+
doc = doc_contents[doc_id]
|
767
|
+
doc["fusion_metadata"] = {
|
768
|
+
"rrf_score": score_info["score"],
|
769
|
+
"query_sources": score_info["query_sources"],
|
770
|
+
"original_ranks": score_info["original_ranks"],
|
771
|
+
"source_diversity": len(set(score_info["query_sources"])),
|
772
|
+
}
|
773
|
+
|
774
|
+
fused_documents.append(doc)
|
775
|
+
fused_scores.append(score_info["score"])
|
776
|
+
|
777
|
+
return {
|
778
|
+
"documents": fused_documents,
|
779
|
+
"scores": fused_scores,
|
780
|
+
"fusion_method": "rrf",
|
781
|
+
"total_unique_docs": len(fused_documents),
|
782
|
+
}
|
783
|
+
|
784
|
+
def _weighted_fusion(self, all_results: List[Dict]) -> Dict[str, Any]:
|
785
|
+
"""Weighted fusion giving higher weight to original query"""
|
786
|
+
weights = [1.0] + [0.7] * (
|
787
|
+
len(all_results) - 1
|
788
|
+
) # Original query gets weight 1.0
|
789
|
+
|
790
|
+
doc_scores = {}
|
791
|
+
doc_contents = {}
|
792
|
+
|
793
|
+
for query_idx, (result, weight) in enumerate(zip(all_results, weights)):
|
794
|
+
documents = result.get("results", [])
|
795
|
+
scores = result.get("scores", [])
|
796
|
+
|
797
|
+
for rank, (doc, score) in enumerate(zip(documents, scores)):
|
798
|
+
doc_id = doc.get("id", doc.get("content", "")[:50])
|
799
|
+
|
800
|
+
weighted_score = score * weight
|
801
|
+
|
802
|
+
if doc_id not in doc_scores:
|
803
|
+
doc_scores[doc_id] = 0.0
|
804
|
+
doc_contents[doc_id] = doc
|
805
|
+
|
806
|
+
doc_scores[doc_id] += weighted_score
|
807
|
+
|
808
|
+
# Sort and format
|
809
|
+
sorted_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)
|
810
|
+
|
811
|
+
return {
|
812
|
+
"documents": [doc_contents[doc_id] for doc_id, _ in sorted_docs],
|
813
|
+
"scores": [score for _, score in sorted_docs],
|
814
|
+
"fusion_method": "weighted",
|
815
|
+
"weights_used": weights,
|
816
|
+
}
|
817
|
+
|
818
|
+
def _simple_concatenation(self, all_results: List[Dict]) -> Dict[str, Any]:
|
819
|
+
"""Simple concatenation with deduplication"""
|
820
|
+
all_docs = []
|
821
|
+
all_scores = []
|
822
|
+
seen_ids = set()
|
823
|
+
|
824
|
+
for result in all_results:
|
825
|
+
documents = result.get("results", [])
|
826
|
+
scores = result.get("scores", [])
|
827
|
+
|
828
|
+
for doc, score in zip(documents, scores):
|
829
|
+
doc_id = doc.get("id", doc.get("content", "")[:50])
|
830
|
+
|
831
|
+
if doc_id not in seen_ids:
|
832
|
+
all_docs.append(doc)
|
833
|
+
all_scores.append(score)
|
834
|
+
seen_ids.add(doc_id)
|
835
|
+
|
836
|
+
return {
|
837
|
+
"documents": all_docs,
|
838
|
+
"scores": all_scores,
|
839
|
+
"fusion_method": "simple_concatenation",
|
840
|
+
}
|
841
|
+
|
842
|
+
def _generate_fused_response(self, original_query: str, fused_results: Dict) -> str:
|
843
|
+
"""Generate final response from fused results"""
|
844
|
+
documents = fused_results.get("documents", [])
|
845
|
+
|
846
|
+
if not documents:
|
847
|
+
return "No relevant documents found after query fusion."
|
848
|
+
|
849
|
+
# Use top documents for response generation
|
850
|
+
top_docs = documents[:5] # Top 5 fused results
|
851
|
+
|
852
|
+
context = "\n\n".join(
|
853
|
+
[
|
854
|
+
f"Source {i+1} (RRF Score: {doc.get('fusion_metadata', {}).get('rrf_score', 0.0):.3f}): "
|
855
|
+
f"{doc.get('content', '')[:400]}..."
|
856
|
+
for i, doc in enumerate(top_docs)
|
857
|
+
]
|
858
|
+
)
|
859
|
+
|
860
|
+
return f"""Based on multiple query perspectives and fused retrieval results for '{original_query}':
|
861
|
+
|
862
|
+
{context}
|
863
|
+
|
864
|
+
[Response generated from {len(documents)} unique documents using {fused_results.get('fusion_method', 'unknown')} fusion]"""
|
865
|
+
|
866
|
+
def _calculate_fusion_improvement(
|
867
|
+
self, individual_results: List[Dict], fused_results: Dict
|
868
|
+
) -> float:
|
869
|
+
"""Calculate improvement provided by fusion"""
|
870
|
+
if not individual_results:
|
871
|
+
return 0.0
|
872
|
+
|
873
|
+
# Compare with best individual result
|
874
|
+
best_individual_count = max(
|
875
|
+
len(result.get("results", [])) for result in individual_results
|
876
|
+
)
|
877
|
+
fused_count = len(fused_results.get("documents", []))
|
878
|
+
|
879
|
+
if best_individual_count == 0:
|
880
|
+
return 0.0
|
881
|
+
|
882
|
+
improvement = (fused_count - best_individual_count) / best_individual_count
|
883
|
+
return round(improvement, 3)
|
884
|
+
|
885
|
+
|
886
|
+
@register_node()
|
887
|
+
class HyDENode(Node):
|
888
|
+
"""
|
889
|
+
HyDE (Hypothetical Document Embeddings)
|
890
|
+
|
891
|
+
Generates hypothetical answers first, then embeds and retrieves
|
892
|
+
based on answer-to-document similarity rather than query-to-document.
|
893
|
+
|
894
|
+
More effective for complex analytical questions where query-document gap is large.
|
895
|
+
|
896
|
+
When to use:
|
897
|
+
- Best for: Complex analytical queries, research questions, abstract concepts
|
898
|
+
- Not ideal for: Factual lookups, keyword-based search
|
899
|
+
- Performance: ~2 seconds (includes hypothesis generation)
|
900
|
+
- Accuracy improvement: 15-30% for complex queries
|
901
|
+
|
902
|
+
Key features:
|
903
|
+
- Hypothetical answer generation
|
904
|
+
- Answer-based similarity matching
|
905
|
+
- Multiple hypothesis support
|
906
|
+
- Zero-shot capability
|
907
|
+
|
908
|
+
Example:
|
909
|
+
hyde = HyDENode(
|
910
|
+
hypothesis_model="gpt-4",
|
911
|
+
use_multiple_hypotheses=True,
|
912
|
+
num_hypotheses=3
|
913
|
+
)
|
914
|
+
|
915
|
+
# Query: "What are the implications of quantum computing for cryptography?"
|
916
|
+
# Generates hypothetical answers:
|
917
|
+
# 1. "Quantum computing poses a significant threat to current..."
|
918
|
+
# 2. "The advent of quantum computers will revolutionize..."
|
919
|
+
# 3. "Cryptographic systems must evolve to be quantum-resistant..."
|
920
|
+
# Then retrieves documents similar to these hypotheses
|
921
|
+
|
922
|
+
result = await hyde.run(
|
923
|
+
documents=documents,
|
924
|
+
query="What are the implications of quantum computing for cryptography?"
|
925
|
+
)
|
926
|
+
|
927
|
+
Parameters:
|
928
|
+
hypothesis_model: LLM for answer generation
|
929
|
+
use_multiple_hypotheses: Generate multiple answers
|
930
|
+
num_hypotheses: Number of hypothetical answers
|
931
|
+
hypothesis_length: Target answer length
|
932
|
+
|
933
|
+
Returns:
|
934
|
+
results: Documents matching hypothetical answers
|
935
|
+
hypotheses_generated: Generated hypothetical answers
|
936
|
+
hyde_metadata: Hypothesis quality and matching stats
|
937
|
+
hypothesis_scores: Individual hypothesis contributions
|
938
|
+
"""
|
939
|
+
|
940
|
+
def __init__(
|
941
|
+
self,
|
942
|
+
name: str = "hyde_rag",
|
943
|
+
hypothesis_model: str = "gpt-4",
|
944
|
+
use_multiple_hypotheses: bool = True,
|
945
|
+
num_hypotheses: int = 2,
|
946
|
+
):
|
947
|
+
self.hypothesis_model = hypothesis_model
|
948
|
+
self.use_multiple_hypotheses = use_multiple_hypotheses
|
949
|
+
self.num_hypotheses = num_hypotheses
|
950
|
+
self.hypothesis_generator = None
|
951
|
+
self.base_rag_workflow = None
|
952
|
+
super().__init__(name)
|
953
|
+
|
954
|
+
def get_parameters(self) -> Dict[str, NodeParameter]:
|
955
|
+
return {
|
956
|
+
"documents": NodeParameter(
|
957
|
+
name="documents",
|
958
|
+
type=list,
|
959
|
+
required=True,
|
960
|
+
description="Documents for HyDE processing",
|
961
|
+
),
|
962
|
+
"query": NodeParameter(
|
963
|
+
name="query",
|
964
|
+
type=str,
|
965
|
+
required=True,
|
966
|
+
description="Query for hypothetical answer generation",
|
967
|
+
),
|
968
|
+
"config": NodeParameter(
|
969
|
+
name="config",
|
970
|
+
type=dict,
|
971
|
+
required=False,
|
972
|
+
description="RAG configuration parameters",
|
973
|
+
),
|
974
|
+
}
|
975
|
+
|
976
|
+
def run(self, **kwargs) -> Dict[str, Any]:
|
977
|
+
"""Execute HyDE (Hypothetical Document Embeddings) approach"""
|
978
|
+
documents = kwargs.get("documents", [])
|
979
|
+
query = kwargs.get("query", "")
|
980
|
+
config = kwargs.get("config", {})
|
981
|
+
|
982
|
+
# Initialize components
|
983
|
+
self._initialize_components(config)
|
984
|
+
|
985
|
+
logger.info(f"HyDE processing query with {len(documents)} documents")
|
986
|
+
|
987
|
+
# Generate hypothetical answer(s)
|
988
|
+
hypotheses = self._generate_hypotheses(query)
|
989
|
+
|
990
|
+
# Retrieve using each hypothesis
|
991
|
+
hypothesis_results = []
|
992
|
+
for i, hypothesis in enumerate(hypotheses):
|
993
|
+
try:
|
994
|
+
result = self._retrieve_with_hypothesis(hypothesis, documents, query)
|
995
|
+
hypothesis_results.append(
|
996
|
+
{
|
997
|
+
"hypothesis": hypothesis,
|
998
|
+
"hypothesis_index": i,
|
999
|
+
"retrieval_result": result,
|
1000
|
+
}
|
1001
|
+
)
|
1002
|
+
except Exception as e:
|
1003
|
+
logger.error(f"HyDE retrieval failed for hypothesis {i}: {e}")
|
1004
|
+
hypothesis_results.append(
|
1005
|
+
{"hypothesis": hypothesis, "hypothesis_index": i, "error": str(e)}
|
1006
|
+
)
|
1007
|
+
|
1008
|
+
# Combine and rank results
|
1009
|
+
combined_results = self._combine_hypothesis_results(hypothesis_results)
|
1010
|
+
|
1011
|
+
# Generate final answer using retrieved documents
|
1012
|
+
final_answer = self._generate_final_answer(query, combined_results, hypotheses)
|
1013
|
+
|
1014
|
+
return {
|
1015
|
+
"original_query": query,
|
1016
|
+
"hypotheses_generated": hypotheses,
|
1017
|
+
"hypothesis_results": hypothesis_results,
|
1018
|
+
"combined_retrieval": combined_results,
|
1019
|
+
"final_answer": final_answer,
|
1020
|
+
"hyde_metadata": {
|
1021
|
+
"num_hypotheses": len(hypotheses),
|
1022
|
+
"successful_retrievals": len(
|
1023
|
+
[r for r in hypothesis_results if "error" not in r]
|
1024
|
+
),
|
1025
|
+
"total_unique_docs": len(
|
1026
|
+
set(
|
1027
|
+
doc.get("id", doc.get("content", "")[:50])
|
1028
|
+
for doc in combined_results.get("documents", [])
|
1029
|
+
)
|
1030
|
+
),
|
1031
|
+
"method": "HyDE",
|
1032
|
+
},
|
1033
|
+
}
|
1034
|
+
|
1035
|
+
def _initialize_components(self, config: Dict[str, Any]):
|
1036
|
+
"""Initialize hypothesis generator and base RAG"""
|
1037
|
+
if not self.hypothesis_generator:
|
1038
|
+
self.hypothesis_generator = LLMAgentNode(
|
1039
|
+
name=f"{self.name}_hypothesis_generator",
|
1040
|
+
model=self.hypothesis_model,
|
1041
|
+
provider="openai",
|
1042
|
+
system_prompt=self._get_hypothesis_generation_prompt(),
|
1043
|
+
)
|
1044
|
+
|
1045
|
+
if not self.base_rag_workflow:
|
1046
|
+
rag_config = RAGConfig(**config) if config else RAGConfig()
|
1047
|
+
self.base_rag_workflow = create_hybrid_rag_workflow(rag_config)
|
1048
|
+
|
1049
|
+
def _get_hypothesis_generation_prompt(self) -> str:
|
1050
|
+
"""Get system prompt for hypothesis generation"""
|
1051
|
+
return f"""You are an expert answer generator for the HyDE (Hypothetical Document Embeddings) technique. Your job is to generate plausible, detailed hypothetical answers to queries.
|
1052
|
+
|
1053
|
+
These hypothetical answers will be used to find similar documents, so they should:
|
1054
|
+
|
1055
|
+
1. **Be Comprehensive**: Cover multiple aspects of the query
|
1056
|
+
2. **Use Domain Language**: Include terminology likely to appear in real documents
|
1057
|
+
3. **Be Specific**: Include concrete details, examples, and explanations
|
1058
|
+
4. **Vary in Approach**: If generating multiple hypotheses, use different angles
|
1059
|
+
|
1060
|
+
Generate {self.num_hypotheses if self.use_multiple_hypotheses else 1} hypothetical answer(s) that would be similar to documents containing the real answer.
|
1061
|
+
|
1062
|
+
Respond with JSON:
|
1063
|
+
{{
|
1064
|
+
"hypotheses": [
|
1065
|
+
"detailed hypothetical answer 1",
|
1066
|
+
{"additional hypotheses if multiple requested"}
|
1067
|
+
],
|
1068
|
+
"reasoning": "brief explanation of hypothesis strategy"
|
1069
|
+
}}"""
|
1070
|
+
|
1071
|
+
def _generate_hypotheses(self, query: str) -> List[str]:
|
1072
|
+
"""Generate hypothetical answers for the query"""
|
1073
|
+
try:
|
1074
|
+
hypothesis_input = f"""
|
1075
|
+
Query: {query}
|
1076
|
+
|
1077
|
+
Generate {self.num_hypotheses if self.use_multiple_hypotheses else 1} detailed hypothetical answer(s) that could help find relevant documents:
|
1078
|
+
"""
|
1079
|
+
|
1080
|
+
response = self.hypothesis_generator.run(
|
1081
|
+
messages=[{"role": "user", "content": hypothesis_input}]
|
1082
|
+
)
|
1083
|
+
|
1084
|
+
hypotheses = self._parse_hypotheses(response)
|
1085
|
+
logger.info(f"Generated {len(hypotheses)} hypotheses")
|
1086
|
+
return hypotheses
|
1087
|
+
|
1088
|
+
except Exception as e:
|
1089
|
+
logger.error(f"Hypothesis generation failed: {e}")
|
1090
|
+
# Fallback to simple hypothesis
|
1091
|
+
return [
|
1092
|
+
f"A comprehensive answer to '{query}' would include detailed explanations and examples."
|
1093
|
+
]
|
1094
|
+
|
1095
|
+
def _parse_hypotheses(self, response: Dict) -> List[str]:
|
1096
|
+
"""Parse hypotheses from LLM response"""
|
1097
|
+
try:
|
1098
|
+
content = response.get("content", "")
|
1099
|
+
if isinstance(content, list):
|
1100
|
+
content = content[0] if content else "{}"
|
1101
|
+
|
1102
|
+
# Extract JSON
|
1103
|
+
if "{" in content and "}" in content:
|
1104
|
+
json_start = content.find("{")
|
1105
|
+
json_end = content.rfind("}") + 1
|
1106
|
+
json_str = content[json_start:json_end]
|
1107
|
+
parsed = json.loads(json_str)
|
1108
|
+
|
1109
|
+
hypotheses = parsed.get("hypotheses", [])
|
1110
|
+
if hypotheses and isinstance(hypotheses, list):
|
1111
|
+
return hypotheses
|
1112
|
+
|
1113
|
+
# Fallback: treat entire content as single hypothesis
|
1114
|
+
return [content] if content else []
|
1115
|
+
|
1116
|
+
except Exception as e:
|
1117
|
+
logger.warning(f"Failed to parse hypotheses: {e}")
|
1118
|
+
return []
|
1119
|
+
|
1120
|
+
def _retrieve_with_hypothesis(
|
1121
|
+
self, hypothesis: str, documents: List[Dict], original_query: str
|
1122
|
+
) -> Dict[str, Any]:
|
1123
|
+
"""Retrieve documents using hypothesis as query"""
|
1124
|
+
# Use hypothesis as the retrieval query instead of original query
|
1125
|
+
result = self.base_rag_workflow.run(
|
1126
|
+
documents=documents,
|
1127
|
+
query=hypothesis, # Key difference: use hypothesis for retrieval
|
1128
|
+
operation="retrieve",
|
1129
|
+
)
|
1130
|
+
|
1131
|
+
# Add metadata about hypothesis-based retrieval
|
1132
|
+
result["hyde_metadata"] = {
|
1133
|
+
"hypothesis_used": hypothesis,
|
1134
|
+
"original_query": original_query,
|
1135
|
+
"retrieval_method": "hypothesis_embedding",
|
1136
|
+
}
|
1137
|
+
|
1138
|
+
return result
|
1139
|
+
|
1140
|
+
def _combine_hypothesis_results(
|
1141
|
+
self, hypothesis_results: List[Dict]
|
1142
|
+
) -> Dict[str, Any]:
|
1143
|
+
"""Combine results from multiple hypotheses"""
|
1144
|
+
all_docs = []
|
1145
|
+
all_scores = []
|
1146
|
+
doc_sources = {} # Track which hypothesis found each doc
|
1147
|
+
|
1148
|
+
for result_info in hypothesis_results:
|
1149
|
+
if "error" in result_info:
|
1150
|
+
continue
|
1151
|
+
|
1152
|
+
retrieval_result = result_info.get("retrieval_result", {})
|
1153
|
+
documents = retrieval_result.get("results", [])
|
1154
|
+
scores = retrieval_result.get("scores", [])
|
1155
|
+
hypothesis_idx = result_info.get("hypothesis_index", 0)
|
1156
|
+
|
1157
|
+
for doc, score in zip(documents, scores):
|
1158
|
+
doc_id = doc.get("id", doc.get("content", "")[:50])
|
1159
|
+
|
1160
|
+
# Track source hypothesis
|
1161
|
+
if doc_id not in doc_sources:
|
1162
|
+
doc_sources[doc_id] = []
|
1163
|
+
all_docs.append(doc)
|
1164
|
+
all_scores.append(score)
|
1165
|
+
|
1166
|
+
doc_sources[doc_id].append(
|
1167
|
+
{"hypothesis_index": hypothesis_idx, "score": score}
|
1168
|
+
)
|
1169
|
+
|
1170
|
+
# Add source information to documents
|
1171
|
+
for doc in all_docs:
|
1172
|
+
doc_id = doc.get("id", doc.get("content", "")[:50])
|
1173
|
+
doc["hyde_sources"] = doc_sources.get(doc_id, [])
|
1174
|
+
doc["source_diversity"] = len(doc_sources.get(doc_id, []))
|
1175
|
+
|
1176
|
+
# Sort by best score from any hypothesis
|
1177
|
+
doc_score_pairs = list(zip(all_docs, all_scores))
|
1178
|
+
doc_score_pairs.sort(key=lambda x: x[1], reverse=True)
|
1179
|
+
|
1180
|
+
sorted_docs, sorted_scores = (
|
1181
|
+
zip(*doc_score_pairs) if doc_score_pairs else ([], [])
|
1182
|
+
)
|
1183
|
+
|
1184
|
+
return {
|
1185
|
+
"documents": list(sorted_docs),
|
1186
|
+
"scores": list(sorted_scores),
|
1187
|
+
"source_tracking": doc_sources,
|
1188
|
+
}
|
1189
|
+
|
1190
|
+
def _generate_final_answer(
|
1191
|
+
self, query: str, combined_results: Dict, hypotheses: List[str]
|
1192
|
+
) -> str:
|
1193
|
+
"""Generate final answer using retrieved documents"""
|
1194
|
+
documents = combined_results.get("documents", [])
|
1195
|
+
|
1196
|
+
if not documents:
|
1197
|
+
return f"No relevant documents found for query: {query}"
|
1198
|
+
|
1199
|
+
# Use top documents
|
1200
|
+
top_docs = documents[:5]
|
1201
|
+
|
1202
|
+
context_parts = []
|
1203
|
+
for i, doc in enumerate(top_docs):
|
1204
|
+
content = doc.get("content", "")[:300]
|
1205
|
+
source_info = doc.get("hyde_sources", [])
|
1206
|
+
diversity = doc.get("source_diversity", 0)
|
1207
|
+
|
1208
|
+
context_parts.append(
|
1209
|
+
f"Document {i+1} (found by {diversity} hypotheses): {content}..."
|
1210
|
+
)
|
1211
|
+
|
1212
|
+
context = "\n\n".join(context_parts)
|
1213
|
+
|
1214
|
+
return f"""Answer to '{query}' based on HyDE retrieval:
|
1215
|
+
|
1216
|
+
{context}
|
1217
|
+
|
1218
|
+
[Generated using {len(hypotheses)} hypothetical answers to improve document matching]"""
|
1219
|
+
|
1220
|
+
|
1221
|
+
@register_node()
|
1222
|
+
class StepBackRAGNode(Node):
|
1223
|
+
"""
|
1224
|
+
Step-Back Prompting for RAG
|
1225
|
+
|
1226
|
+
Generates abstract, higher-level questions to retrieve background information
|
1227
|
+
before addressing the specific query. Improves context and reasoning.
|
1228
|
+
|
1229
|
+
When to use:
|
1230
|
+
- Best for: "Why" questions, conceptual understanding, background needed
|
1231
|
+
- Not ideal for: Direct factual queries, simple lookups
|
1232
|
+
- Performance: ~1.5 seconds for dual retrieval
|
1233
|
+
- Context improvement: 30-50% better background coverage
|
1234
|
+
|
1235
|
+
Key features:
|
1236
|
+
- Abstract query generation
|
1237
|
+
- Dual retrieval (specific + abstract)
|
1238
|
+
- Weighted result combination
|
1239
|
+
- Context-aware answering
|
1240
|
+
|
1241
|
+
Example:
|
1242
|
+
step_back = StepBackRAGNode(
|
1243
|
+
abstraction_model="gpt-4"
|
1244
|
+
)
|
1245
|
+
|
1246
|
+
# Query: "Why does batch normalization help neural networks?"
|
1247
|
+
# Generates abstract: "What is normalization in machine learning?"
|
1248
|
+
# Retrieves:
|
1249
|
+
# - Specific docs about batch normalization benefits
|
1250
|
+
# - Abstract docs about normalization concepts
|
1251
|
+
# Combines both for comprehensive answer
|
1252
|
+
|
1253
|
+
result = await step_back.run(
|
1254
|
+
documents=documents,
|
1255
|
+
query="Why does batch normalization help neural networks?"
|
1256
|
+
)
|
1257
|
+
|
1258
|
+
Parameters:
|
1259
|
+
abstraction_model: LLM for abstract query generation
|
1260
|
+
abstraction_level: How abstract to make queries
|
1261
|
+
combination_weights: Balance of specific vs abstract
|
1262
|
+
include_reasoning: Add step-back reasoning to results
|
1263
|
+
|
1264
|
+
Returns:
|
1265
|
+
results: Combined specific and abstract documents
|
1266
|
+
specific_query: Original query
|
1267
|
+
abstract_query: Generated abstract version
|
1268
|
+
step_back_metadata: Abstraction quality and statistics
|
1269
|
+
reasoning_chain: How abstract helps answer specific
|
1270
|
+
"""
|
1271
|
+
|
1272
|
+

    def __init__(self, name: str = "step_back_rag", abstraction_model: str = "gpt-4"):
        self.abstraction_model = abstraction_model
        self.abstraction_generator = None
        self.base_rag_workflow = None
        super().__init__(name)

    def get_parameters(self) -> Dict[str, NodeParameter]:
        return {
            "documents": NodeParameter(
                name="documents",
                type=list,
                required=True,
                description="Documents for step-back RAG processing",
            ),
            "query": NodeParameter(
                name="query",
                type=str,
                required=True,
                description="Specific query for step-back processing",
            ),
            "config": NodeParameter(
                name="config",
                type=dict,
                required=False,
                description="RAG configuration parameters",
            ),
        }

    def run(self, **kwargs) -> Dict[str, Any]:
        """Execute Step-Back RAG with abstract reasoning"""
        documents = kwargs.get("documents", [])
        specific_query = kwargs.get("query", "")
        config = kwargs.get("config", {})

        # Initialize components
        self._initialize_components(config)

        logger.info("Step-Back RAG processing specific query")

        # Generate abstract (step-back) question
        abstract_query = self._generate_abstract_query(specific_query)

        # Retrieve with both queries
        specific_results = self._retrieve_for_query(
            specific_query, documents, "specific"
        )
        abstract_results = self._retrieve_for_query(
            abstract_query, documents, "abstract"
        )

        # Combine results with proper weighting
        combined_results = self._combine_step_back_results(
            specific_results, abstract_results, specific_query, abstract_query
        )

        # Generate comprehensive answer
        final_answer = self._generate_step_back_answer(
            specific_query, abstract_query, combined_results
        )

        return {
            "specific_query": specific_query,
            "abstract_query": abstract_query,
            "specific_retrieval": specific_results,
            "abstract_retrieval": abstract_results,
            "combined_results": combined_results,
            "final_answer": final_answer,
            "step_back_metadata": {
                "abstraction_successful": bool(
                    abstract_query and abstract_query != specific_query
                ),
                "specific_docs_count": len(specific_results.get("results", [])),
                "abstract_docs_count": len(abstract_results.get("results", [])),
                "combined_docs_count": len(combined_results.get("documents", [])),
                "method": "step_back_prompting",
            },
        }

    def _initialize_components(self, config: Dict[str, Any]):
        """Initialize abstraction generator and base RAG"""
        if not self.abstraction_generator:
            self.abstraction_generator = LLMAgentNode(
                name=f"{self.name}_abstraction_generator",
                model=self.abstraction_model,
                provider="openai",
                system_prompt=self._get_abstraction_prompt(),
            )

        if not self.base_rag_workflow:
            rag_config = RAGConfig(**config) if config else RAGConfig()
            self.base_rag_workflow = create_hybrid_rag_workflow(rag_config)

    def _get_abstraction_prompt(self) -> str:
        """Get system prompt for step-back abstraction"""
        return """You are an expert at abstract reasoning and question formulation. Your job is to take specific, detailed questions and generate broader, more abstract versions that would help retrieve useful background information.

Step-Back Technique:
1. **Identify Core Concepts**: What are the fundamental concepts in the question?
2. **Generalize**: Create a broader question about those concepts
3. **Background Focus**: The abstract question should retrieve foundational knowledge
4. **Maintain Relevance**: Keep connection to original query intent

Examples:
- Specific: "How does the gradient descent algorithm work in neural networks?"
- Abstract: "What are the fundamental optimization techniques used in machine learning?"

- Specific: "What are the side effects of ibuprofen for children?"
- Abstract: "What are the general principles of pediatric medication safety?"

Respond with JSON:
{
    "abstract_query": "broader, more general version of the query",
    "reasoning": "explanation of abstraction strategy",
    "concepts_identified": ["list", "of", "core", "concepts"]
}"""

    def _generate_abstract_query(self, specific_query: str) -> str:
        """Generate abstract step-back query"""
        try:
            abstraction_input = f"""
Specific Query: {specific_query}

Generate a broader, more abstract version that would help retrieve relevant background information:
"""

            response = self.abstraction_generator.run(
                messages=[{"role": "user", "content": abstraction_input}]
            )

            abstract_query = self._parse_abstract_query(response)
            logger.info(f"Generated abstract query: {abstract_query}")
            return abstract_query

        except Exception as e:
            logger.error(f"Abstract query generation failed: {e}")
            # Fallback to simple abstraction
            return self._generate_fallback_abstraction(specific_query)

    def _parse_abstract_query(self, response: Dict) -> str:
        """Parse abstract query from LLM response"""
        try:
            content = response.get("content", "")
            if isinstance(content, list):
                content = content[0] if content else "{}"

            # Extract JSON
            if "{" in content and "}" in content:
                json_start = content.find("{")
                json_end = content.rfind("}") + 1
                json_str = content[json_start:json_end]
                parsed = json.loads(json_str)

                abstract_query = parsed.get("abstract_query", "")
                if abstract_query:
                    return abstract_query

            # Fallback: extract first question-like sentence
            sentences = content.split(".")
            for sentence in sentences:
                if (
                    "?" in sentence
                    or "what" in sentence.lower()
                    or "how" in sentence.lower()
                ):
                    return sentence.strip()

            return content.strip()

        except Exception as e:
            logger.warning(f"Failed to parse abstract query: {e}")
            return content if isinstance(content, str) else ""

    def _generate_fallback_abstraction(self, specific_query: str) -> str:
        """Generate simple abstraction when LLM fails"""
        # Simple patterns for abstraction
        words = specific_query.lower().split()

        if "how" in words:
            # "How does X work?" -> "What are the general principles of X?"
            return f"What are the general principles related to the topics in: {specific_query}"
        elif "what" in words:
            # "What is X?" -> "What are the broader concepts around X?"
            return f"What are the broader concepts and background for: {specific_query}"
        else:
            # Generic abstraction
            return f"What is the general background and context for: {specific_query}"

    def _retrieve_for_query(
        self, query: str, documents: List[Dict], query_type: str
    ) -> Dict[str, Any]:
        """Retrieve documents for specific or abstract query"""
        result = self.base_rag_workflow.run(
            documents=documents, query=query, operation="retrieve"
        )
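        # Assumption, not verified here: the hybrid RAG workflow's result dict
        # exposes "results" and "scores" lists, which _combine_step_back_results
        # below reads via .get("results") / .get("scores").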

        result["query_type"] = query_type
        result["query_used"] = query

        return result

    def _combine_step_back_results(
        self,
        specific_results: Dict,
        abstract_results: Dict,
        specific_query: str,
        abstract_query: str,
    ) -> Dict[str, Any]:
        """Combine specific and abstract retrieval results"""
        # Weight specific results higher (0.7) than abstract (0.3)
        specific_weight = 0.7
        abstract_weight = 0.3

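        # Worked example of this weighting (illustrative numbers only): a
        # specific-query document scored 0.8 keeps 0.8 * 0.7 = 0.56, while an
        # abstract-only document scored 0.9 drops to 0.9 * 0.3 = 0.27, so
        # directly relevant documents still rank first. Documents returned by
        # both queries are kept once, with their specific-query weighting.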
        combined_docs = []
        doc_sources = {}

        # Add specific results with higher weight
        specific_docs = specific_results.get("results", [])
        specific_scores = specific_results.get("scores", [])

        for doc, score in zip(specific_docs, specific_scores):
            doc_id = doc.get("id", doc.get("content", "")[:50])
            weighted_score = score * specific_weight

            doc_with_metadata = doc.copy()
            doc_with_metadata["step_back_metadata"] = {
                "source_type": "specific",
                "original_score": score,
                "weighted_score": weighted_score,
                "source_query": specific_query,
            }

            combined_docs.append((doc_with_metadata, weighted_score, doc_id))
            doc_sources[doc_id] = "specific"

        # Add abstract results with lower weight (avoid duplicates)
        abstract_docs = abstract_results.get("results", [])
        abstract_scores = abstract_results.get("scores", [])

        for doc, score in zip(abstract_docs, abstract_scores):
            doc_id = doc.get("id", doc.get("content", "")[:50])

            # Skip if already added from specific results
            if doc_id in doc_sources:
                continue

            weighted_score = score * abstract_weight

            doc_with_metadata = doc.copy()
            doc_with_metadata["step_back_metadata"] = {
                "source_type": "abstract",
                "original_score": score,
                "weighted_score": weighted_score,
                "source_query": abstract_query,
            }

            combined_docs.append((doc_with_metadata, weighted_score, doc_id))
            doc_sources[doc_id] = "abstract"

        # Sort by weighted score
        combined_docs.sort(key=lambda x: x[1], reverse=True)

        # Extract sorted documents and scores
        sorted_docs = [doc for doc, _, _ in combined_docs]
        sorted_scores = [score for _, score, _ in combined_docs]

        return {
            "documents": sorted_docs,
            "scores": sorted_scores,
            "source_breakdown": {
                "specific_count": len(specific_docs),
                "abstract_count": len(abstract_docs),
                "total_unique": len(combined_docs),
                "weights_used": {
                    "specific": specific_weight,
                    "abstract": abstract_weight,
                },
            },
        }

    def _generate_step_back_answer(
        self, specific_query: str, abstract_query: str, combined_results: Dict
    ) -> str:
        """Generate comprehensive answer using step-back approach"""
        documents = combined_results.get("documents", [])

        if not documents:
            return f"No relevant documents found for query: {specific_query}"

        # Separate background and specific information
        background_docs = [
            doc
            for doc in documents[:3]
            if doc.get("step_back_metadata", {}).get("source_type") == "abstract"
        ]
        specific_docs = [
            doc
            for doc in documents[:5]
            if doc.get("step_back_metadata", {}).get("source_type") == "specific"
        ]

        # Build response with background context first
        response_parts = [f"Answer to: {specific_query}"]

        if background_docs:
            response_parts.append("\nBackground Context:")
            for i, doc in enumerate(background_docs):
                content = doc.get("content", "")[:250]
                response_parts.append(f"Background {i+1}: {content}...")

        if specific_docs:
            response_parts.append("\nSpecific Information:")
            for i, doc in enumerate(specific_docs):
                content = doc.get("content", "")[:300]
                response_parts.append(f"Specific {i+1}: {content}...")

        response_parts.append(
            f"\n[Generated using step-back reasoning with abstract query: '{abstract_query}']"
        )

        return "\n".join(response_parts)


# Update the __init__.py to include new advanced nodes
def update_init_file():
    """Add new advanced RAG nodes to __init__.py"""
    new_imports = """
from .advanced import (
    SelfCorrectingRAGNode,
    RAGFusionNode,
    HyDENode,
    StepBackRAGNode
)
"""

    new_exports = """
    # Advanced RAG Techniques
    "SelfCorrectingRAGNode",
    "RAGFusionNode",
    "HyDENode",
    "StepBackRAGNode",
"""

    # This would be added to the existing __init__.py file
    return new_imports, new_exports
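

# Minimal usage sketch for the node above (illustrative only; assumes documents
# are plain dicts with "id" and "content" keys, as the methods above expect,
# and that an OpenAI-backed LLM is available for abstraction).
def _example_step_back_usage():
    step_back = StepBackRAGNode(abstraction_model="gpt-4")
    result = step_back.run(
        documents=[
            {
                "id": "doc-1",
                "content": "Batch normalization stabilizes layer inputs during training ...",
            }
        ],
        query="Why does batch normalization help neural networks?",
    )
    # run() returns the specific and abstract queries, both raw retrievals,
    # the weighted combination, and a synthesized final answer.
    print(result["abstract_query"])
    print(result["final_answer"])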