kailash 0.3.2__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +33 -1
- kailash/access_control/__init__.py +129 -0
- kailash/access_control/managers.py +461 -0
- kailash/access_control/rule_evaluators.py +467 -0
- kailash/access_control_abac.py +825 -0
- kailash/config/__init__.py +27 -0
- kailash/config/database_config.py +359 -0
- kailash/database/__init__.py +28 -0
- kailash/database/execution_pipeline.py +499 -0
- kailash/middleware/__init__.py +306 -0
- kailash/middleware/auth/__init__.py +33 -0
- kailash/middleware/auth/access_control.py +436 -0
- kailash/middleware/auth/auth_manager.py +422 -0
- kailash/middleware/auth/jwt_auth.py +477 -0
- kailash/middleware/auth/kailash_jwt_auth.py +616 -0
- kailash/middleware/communication/__init__.py +37 -0
- kailash/middleware/communication/ai_chat.py +989 -0
- kailash/middleware/communication/api_gateway.py +802 -0
- kailash/middleware/communication/events.py +470 -0
- kailash/middleware/communication/realtime.py +710 -0
- kailash/middleware/core/__init__.py +21 -0
- kailash/middleware/core/agent_ui.py +890 -0
- kailash/middleware/core/schema.py +643 -0
- kailash/middleware/core/workflows.py +396 -0
- kailash/middleware/database/__init__.py +63 -0
- kailash/middleware/database/base.py +113 -0
- kailash/middleware/database/base_models.py +525 -0
- kailash/middleware/database/enums.py +106 -0
- kailash/middleware/database/migrations.py +12 -0
- kailash/{api/database.py → middleware/database/models.py} +183 -291
- kailash/middleware/database/repositories.py +685 -0
- kailash/middleware/database/session_manager.py +19 -0
- kailash/middleware/mcp/__init__.py +38 -0
- kailash/middleware/mcp/client_integration.py +585 -0
- kailash/middleware/mcp/enhanced_server.py +576 -0
- kailash/nodes/__init__.py +27 -3
- kailash/nodes/admin/__init__.py +42 -0
- kailash/nodes/admin/audit_log.py +794 -0
- kailash/nodes/admin/permission_check.py +864 -0
- kailash/nodes/admin/role_management.py +823 -0
- kailash/nodes/admin/security_event.py +1523 -0
- kailash/nodes/admin/user_management.py +944 -0
- kailash/nodes/ai/a2a.py +24 -7
- kailash/nodes/ai/ai_providers.py +248 -40
- kailash/nodes/ai/embedding_generator.py +11 -11
- kailash/nodes/ai/intelligent_agent_orchestrator.py +99 -11
- kailash/nodes/ai/llm_agent.py +436 -5
- kailash/nodes/ai/self_organizing.py +85 -10
- kailash/nodes/ai/vision_utils.py +148 -0
- kailash/nodes/alerts/__init__.py +26 -0
- kailash/nodes/alerts/base.py +234 -0
- kailash/nodes/alerts/discord.py +499 -0
- kailash/nodes/api/auth.py +287 -6
- kailash/nodes/api/rest.py +151 -0
- kailash/nodes/auth/__init__.py +17 -0
- kailash/nodes/auth/directory_integration.py +1228 -0
- kailash/nodes/auth/enterprise_auth_provider.py +1328 -0
- kailash/nodes/auth/mfa.py +2338 -0
- kailash/nodes/auth/risk_assessment.py +872 -0
- kailash/nodes/auth/session_management.py +1093 -0
- kailash/nodes/auth/sso.py +1040 -0
- kailash/nodes/base.py +344 -13
- kailash/nodes/base_cycle_aware.py +4 -2
- kailash/nodes/base_with_acl.py +1 -1
- kailash/nodes/code/python.py +283 -10
- kailash/nodes/compliance/__init__.py +9 -0
- kailash/nodes/compliance/data_retention.py +1888 -0
- kailash/nodes/compliance/gdpr.py +2004 -0
- kailash/nodes/data/__init__.py +22 -2
- kailash/nodes/data/async_connection.py +469 -0
- kailash/nodes/data/async_sql.py +757 -0
- kailash/nodes/data/async_vector.py +598 -0
- kailash/nodes/data/readers.py +767 -0
- kailash/nodes/data/retrieval.py +360 -1
- kailash/nodes/data/sharepoint_graph.py +397 -21
- kailash/nodes/data/sql.py +94 -5
- kailash/nodes/data/streaming.py +68 -8
- kailash/nodes/data/vector_db.py +54 -4
- kailash/nodes/enterprise/__init__.py +13 -0
- kailash/nodes/enterprise/batch_processor.py +741 -0
- kailash/nodes/enterprise/data_lineage.py +497 -0
- kailash/nodes/logic/convergence.py +31 -9
- kailash/nodes/logic/operations.py +14 -3
- kailash/nodes/mixins/__init__.py +8 -0
- kailash/nodes/mixins/event_emitter.py +201 -0
- kailash/nodes/mixins/mcp.py +9 -4
- kailash/nodes/mixins/security.py +165 -0
- kailash/nodes/monitoring/__init__.py +7 -0
- kailash/nodes/monitoring/performance_benchmark.py +2497 -0
- kailash/nodes/rag/__init__.py +284 -0
- kailash/nodes/rag/advanced.py +1615 -0
- kailash/nodes/rag/agentic.py +773 -0
- kailash/nodes/rag/conversational.py +999 -0
- kailash/nodes/rag/evaluation.py +875 -0
- kailash/nodes/rag/federated.py +1188 -0
- kailash/nodes/rag/graph.py +721 -0
- kailash/nodes/rag/multimodal.py +671 -0
- kailash/nodes/rag/optimized.py +933 -0
- kailash/nodes/rag/privacy.py +1059 -0
- kailash/nodes/rag/query_processing.py +1335 -0
- kailash/nodes/rag/realtime.py +764 -0
- kailash/nodes/rag/registry.py +547 -0
- kailash/nodes/rag/router.py +837 -0
- kailash/nodes/rag/similarity.py +1854 -0
- kailash/nodes/rag/strategies.py +566 -0
- kailash/nodes/rag/workflows.py +575 -0
- kailash/nodes/security/__init__.py +19 -0
- kailash/nodes/security/abac_evaluator.py +1411 -0
- kailash/nodes/security/audit_log.py +103 -0
- kailash/nodes/security/behavior_analysis.py +1893 -0
- kailash/nodes/security/credential_manager.py +401 -0
- kailash/nodes/security/rotating_credentials.py +760 -0
- kailash/nodes/security/security_event.py +133 -0
- kailash/nodes/security/threat_detection.py +1103 -0
- kailash/nodes/testing/__init__.py +9 -0
- kailash/nodes/testing/credential_testing.py +499 -0
- kailash/nodes/transform/__init__.py +10 -2
- kailash/nodes/transform/chunkers.py +592 -1
- kailash/nodes/transform/processors.py +484 -14
- kailash/nodes/validation.py +321 -0
- kailash/runtime/access_controlled.py +1 -1
- kailash/runtime/async_local.py +41 -7
- kailash/runtime/docker.py +1 -1
- kailash/runtime/local.py +474 -55
- kailash/runtime/parallel.py +1 -1
- kailash/runtime/parallel_cyclic.py +1 -1
- kailash/runtime/testing.py +210 -2
- kailash/security.py +1 -1
- kailash/utils/migrations/__init__.py +25 -0
- kailash/utils/migrations/generator.py +433 -0
- kailash/utils/migrations/models.py +231 -0
- kailash/utils/migrations/runner.py +489 -0
- kailash/utils/secure_logging.py +342 -0
- kailash/workflow/__init__.py +16 -0
- kailash/workflow/cyclic_runner.py +3 -4
- kailash/workflow/graph.py +70 -2
- kailash/workflow/resilience.py +249 -0
- kailash/workflow/templates.py +726 -0
- {kailash-0.3.2.dist-info → kailash-0.4.1.dist-info}/METADATA +256 -20
- kailash-0.4.1.dist-info/RECORD +227 -0
- kailash/api/__init__.py +0 -17
- kailash/api/__main__.py +0 -6
- kailash/api/studio_secure.py +0 -893
- kailash/mcp/__main__.py +0 -13
- kailash/mcp/server_new.py +0 -336
- kailash/mcp/servers/__init__.py +0 -12
- kailash-0.3.2.dist-info/RECORD +0 -136
- {kailash-0.3.2.dist-info → kailash-0.4.1.dist-info}/WHEEL +0 -0
- {kailash-0.3.2.dist-info → kailash-0.4.1.dist-info}/entry_points.txt +0 -0
- {kailash-0.3.2.dist-info → kailash-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.3.2.dist-info → kailash-0.4.1.dist-info}/top_level.txt +0 -0
kailash/nodes/data/retrieval.py
CHANGED
```diff
@@ -1,6 +1,7 @@
 """Document retrieval nodes for finding relevant content using various similarity methods."""
 
-
+import json
+from typing import Any, Dict, List, Optional
 
 from kailash.nodes.base import Node, NodeParameter, register_node
 
```
```diff
@@ -107,10 +108,19 @@ class RelevanceScorerNode(Node):
         # Handle query embedding - should be the first (and only) embedding in the list
         query_embedding_obj = query_embeddings[0] if query_embeddings else {}
         if isinstance(query_embedding_obj, dict) and "embedding" in query_embedding_obj:
+            # Handle Ollama format: {"embedding": [...]}
             query_embedding = query_embedding_obj["embedding"]
+        elif (
+            isinstance(query_embedding_obj, dict)
+            and "embeddings" in query_embedding_obj
+        ):
+            # Handle other provider formats: {"embeddings": [...]}
+            query_embedding = query_embedding_obj["embeddings"]
         elif isinstance(query_embedding_obj, list):
+            # Handle direct list format
             query_embedding = query_embedding_obj
         else:
+            # Fallback
             query_embedding = []
 
         print(
```
```diff
@@ -149,10 +159,19 @@
                 isinstance(chunk_embedding_obj, dict)
                 and "embedding" in chunk_embedding_obj
             ):
+                # Handle Ollama format: {"embedding": [...]}
                 chunk_embedding = chunk_embedding_obj["embedding"]
+            elif (
+                isinstance(chunk_embedding_obj, dict)
+                and "embeddings" in chunk_embedding_obj
+            ):
+                # Handle other provider formats: {"embeddings": [...]}
+                chunk_embedding = chunk_embedding_obj["embeddings"]
             elif isinstance(chunk_embedding_obj, list):
+                # Handle direct list format
                 chunk_embedding = chunk_embedding_obj
             else:
+                # Fallback
                 chunk_embedding = []
 
             similarity = cosine_similarity(query_embedding, chunk_embedding)
```
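Both hunks apply the same fallback chain to query and chunk embeddings: accept Ollama's `{"embedding": [...]}`, other providers' `{"embeddings": [...]}`, or a bare list, and otherwise fall back to an empty vector. A minimal standalone sketch of that dispatch (the `extract_embedding` helper is ours for illustration, not part of the package):

```python
from typing import Any, List


def extract_embedding(obj: Any) -> List[float]:
    """Normalize provider-specific embedding payloads to a flat vector.

    Hypothetical helper mirroring the dispatch added in the hunks above.
    """
    if isinstance(obj, dict) and "embedding" in obj:
        return obj["embedding"]  # Ollama format: {"embedding": [...]}
    if isinstance(obj, dict) and "embeddings" in obj:
        return obj["embeddings"]  # other provider formats: {"embeddings": [...]}
    if isinstance(obj, list):
        return obj  # already a bare vector
    return []  # fallback for unknown shapes


assert extract_embedding({"embedding": [0.1, 0.2]}) == [0.1, 0.2]
assert extract_embedding({"embeddings": [0.1, 0.2]}) == [0.1, 0.2]
assert extract_embedding([0.1, 0.2]) == [0.1, 0.2]
assert extract_embedding("oops") == []
```

The large hunk below adds the new `HybridRetrieverNode` class.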
```diff
@@ -176,3 +195,343 @@
         # TODO: Implement TF-IDF scoring
         # For now, return chunks with default scores
         return [{**chunk, "relevance_score": 0.5} for chunk in chunks]
+
+
+@register_node()
+class HybridRetrieverNode(Node):
+    """
+    Hybrid retrieval combining dense and sparse retrieval methods.
+
+    This node implements state-of-the-art hybrid retrieval that combines:
+    - Dense retrieval (semantic embeddings)
+    - Sparse retrieval (keyword-based like BM25)
+    - Multiple fusion strategies (RRF, linear combination, learned fusion)
+
+    Hybrid retrieval typically provides 20-30% better results than single methods.
+    """
+
+    def __init__(self, name: str = "hybrid_retriever", **kwargs):
+        # Set attributes before calling super().__init__() as Kailash validates during init
+        self.fusion_strategy = kwargs.get(
+            "fusion_strategy", "rrf"
+        )  # "rrf", "linear", "weighted"
+        self.dense_weight = kwargs.get("dense_weight", 0.6)
+        self.sparse_weight = kwargs.get("sparse_weight", 0.4)
+        self.rrf_k = kwargs.get("rrf_k", 60)
+        self.top_k = kwargs.get("top_k", 5)
+        self.normalize_scores = kwargs.get("normalize_scores", True)
+
+        super().__init__(name=name)
+
+    def get_parameters(self) -> dict[str, NodeParameter]:
+        return {
+            "query": NodeParameter(
+                name="query",
+                type=str,
+                required=True,
+                description="Search query",
+            ),
+            "dense_results": NodeParameter(
+                name="dense_results",
+                type=list,
+                required=True,
+                description="Results from dense retrieval (with similarity_score)",
+            ),
+            "sparse_results": NodeParameter(
+                name="sparse_results",
+                type=list,
+                required=True,
+                description="Results from sparse retrieval (with similarity_score)",
+            ),
+            "fusion_strategy": NodeParameter(
+                name="fusion_strategy",
+                type=str,
+                required=False,
+                default=self.fusion_strategy,
+                description="Fusion strategy: rrf, linear, or weighted",
+            ),
+            "dense_weight": NodeParameter(
+                name="dense_weight",
+                type=float,
+                required=False,
+                default=self.dense_weight,
+                description="Weight for dense retrieval scores (0.0-1.0)",
+            ),
+            "sparse_weight": NodeParameter(
+                name="sparse_weight",
+                type=float,
+                required=False,
+                default=self.sparse_weight,
+                description="Weight for sparse retrieval scores (0.0-1.0)",
+            ),
+            "top_k": NodeParameter(
+                name="top_k",
+                type=int,
+                required=False,
+                default=self.top_k,
+                description="Number of top results to return",
+            ),
+            "rrf_k": NodeParameter(
+                name="rrf_k",
+                type=int,
+                required=False,
+                default=self.rrf_k,
+                description="RRF parameter k (higher = less aggressive fusion)",
+            ),
+        }
+
+    def run(self, **kwargs) -> dict[str, Any]:
+        query = kwargs.get("query", "")
+        dense_results = kwargs.get("dense_results", [])
+        sparse_results = kwargs.get("sparse_results", [])
+        fusion_strategy = kwargs.get("fusion_strategy", self.fusion_strategy)
+        dense_weight = kwargs.get("dense_weight", self.dense_weight)
+        sparse_weight = kwargs.get("sparse_weight", self.sparse_weight)
+        top_k = kwargs.get("top_k", self.top_k)
+        rrf_k = kwargs.get("rrf_k", self.rrf_k)
+
+        if not dense_results and not sparse_results:
+            return {
+                "hybrid_results": [],
+                "fusion_method": fusion_strategy,
+                "dense_count": 0,
+                "sparse_count": 0,
+                "fused_count": 0,
+            }
+
+        # Ensure results have required fields
+        dense_results = self._normalize_results(dense_results, "dense")
+        sparse_results = self._normalize_results(sparse_results, "sparse")
+
+        # Apply fusion strategy
+        if fusion_strategy == "rrf":
+            fused_results = self._reciprocal_rank_fusion(
+                dense_results, sparse_results, top_k, rrf_k
+            )
+        elif fusion_strategy == "linear":
+            fused_results = self._linear_fusion(
+                dense_results, sparse_results, top_k, dense_weight, sparse_weight
+            )
+        elif fusion_strategy == "weighted":
+            fused_results = self._weighted_fusion(
+                dense_results, sparse_results, top_k, dense_weight, sparse_weight
+            )
+        else:
+            # Default to RRF
+            fused_results = self._reciprocal_rank_fusion(
+                dense_results, sparse_results, top_k, rrf_k
+            )
+
+        return {
+            "hybrid_results": fused_results,
+            "fusion_method": fusion_strategy,
+            "dense_count": len(dense_results),
+            "sparse_count": len(sparse_results),
+            "fused_count": len(fused_results),
+        }
+
+    def _normalize_results(self, results: List[Dict], source: str) -> List[Dict]:
+        """Normalize results to ensure consistent format."""
+        normalized = []
+
+        for i, result in enumerate(results):
+            # Ensure required fields exist
+            normalized_result = {
+                "id": result.get("id", result.get("chunk_id", f"{source}_{i}")),
+                "content": result.get("content", result.get("text", "")),
+                "similarity_score": result.get(
+                    "similarity_score", result.get("score", 0.0)
+                ),
+                "source": source,
+                **result,  # Keep original fields
+            }
+            normalized.append(normalized_result)
+
+        return normalized
+
+    def _reciprocal_rank_fusion(
+        self,
+        dense_results: List[Dict],
+        sparse_results: List[Dict],
+        top_k: int,
+        rrf_k: int,
+    ) -> List[Dict]:
+        """
+        Implement Reciprocal Rank Fusion (RRF).
+
+        RRF formula: RRF(d) = Σ(1 / (k + rank_i(d)))
+        where rank_i(d) is the rank of document d in ranklist i
+        """
+        # Create rank mappings
+        dense_ranks = {doc["id"]: i + 1 for i, doc in enumerate(dense_results)}
+        sparse_ranks = {doc["id"]: i + 1 for i, doc in enumerate(sparse_results)}
+
+        # Collect all unique document IDs
+        all_doc_ids = set(dense_ranks.keys()) | set(sparse_ranks.keys())
+
+        # Calculate RRF scores
+        rrf_scores = {}
+        for doc_id in all_doc_ids:
+            score = 0.0
+
+            if doc_id in dense_ranks:
+                score += 1.0 / (rrf_k + dense_ranks[doc_id])
+
+            if doc_id in sparse_ranks:
+                score += 1.0 / (rrf_k + sparse_ranks[doc_id])
+
+            rrf_scores[doc_id] = score
+
+        # Sort by RRF score and get top-k
+        sorted_docs = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)[
+            :top_k
+        ]
+
+        # Build result documents
+        doc_map = {}
+        for doc in dense_results + sparse_results:
+            doc_map[doc["id"]] = doc
+
+        results = []
+        for doc_id, rrf_score in sorted_docs:
+            if doc_id in doc_map:
+                doc = doc_map[doc_id].copy()
+                doc["hybrid_score"] = rrf_score
+                doc["fusion_method"] = "rrf"
+                doc["rank"] = len(results) + 1
+                results.append(doc)
+
+        return results
+
+    def _linear_fusion(
+        self,
+        dense_results: List[Dict],
+        sparse_results: List[Dict],
+        top_k: int,
+        dense_weight: float,
+        sparse_weight: float,
+    ) -> List[Dict]:
+        """
+        Implement linear combination fusion.
+
+        Score = dense_weight * dense_score + sparse_weight * sparse_score
+        """
+        if self.normalize_scores:
+            # Normalize scores to 0-1 range
+            dense_scores = [doc["similarity_score"] for doc in dense_results]
+            sparse_scores = [doc["similarity_score"] for doc in sparse_results]
+
+            dense_max = max(dense_scores) if dense_scores else 1.0
+            sparse_max = max(sparse_scores) if sparse_scores else 1.0
+
+            # Avoid division by zero
+            dense_max = max(dense_max, 1e-8)
+            sparse_max = max(sparse_max, 1e-8)
+        else:
+            dense_max = sparse_max = 1.0
+
+        # Create score mappings
+        dense_score_map = {
+            doc["id"]: doc["similarity_score"] / dense_max for doc in dense_results
+        }
+        sparse_score_map = {
+            doc["id"]: doc["similarity_score"] / sparse_max for doc in sparse_results
+        }
+
+        # Collect all unique document IDs
+        all_doc_ids = set(dense_score_map.keys()) | set(sparse_score_map.keys())
+
+        # Calculate linear combination scores
+        linear_scores = {}
+        for doc_id in all_doc_ids:
+            dense_score = dense_score_map.get(doc_id, 0.0)
+            sparse_score = sparse_score_map.get(doc_id, 0.0)
+
+            combined_score = dense_weight * dense_score + sparse_weight * sparse_score
+            linear_scores[doc_id] = combined_score
+
+        # Sort and build results
+        sorted_docs = sorted(linear_scores.items(), key=lambda x: x[1], reverse=True)[
+            :top_k
+        ]
+
+        # Build result documents
+        doc_map = {}
+        for doc in dense_results + sparse_results:
+            doc_map[doc["id"]] = doc
+
+        results = []
+        for doc_id, combined_score in sorted_docs:
+            if doc_id in doc_map:
+                doc = doc_map[doc_id].copy()
+                doc["hybrid_score"] = combined_score
+                doc["fusion_method"] = "linear"
+                doc["rank"] = len(results) + 1
+                results.append(doc)
+
+        return results
+
+    def _weighted_fusion(
+        self,
+        dense_results: List[Dict],
+        sparse_results: List[Dict],
+        top_k: int,
+        dense_weight: float,
+        sparse_weight: float,
+    ) -> List[Dict]:
+        """
+        Implement weighted fusion with rank-based scoring.
+
+        Combines position-based weighting with score-based weighting.
+        """
+        # Normalize weights
+        total_weight = dense_weight + sparse_weight
+        if total_weight > 0:
+            dense_weight = dense_weight / total_weight
+            sparse_weight = sparse_weight / total_weight
+        else:
+            dense_weight = sparse_weight = 0.5
+
+        # Calculate weighted scores
+        weighted_scores = {}
+
+        # Process dense results
+        for i, doc in enumerate(dense_results):
+            doc_id = doc["id"]
+            # Combine similarity score with rank-based discount
+            rank_score = 1.0 / (i + 1)  # Higher ranks get higher scores
+            weighted_score = dense_weight * (
+                doc["similarity_score"] * 0.7 + rank_score * 0.3
+            )
+            weighted_scores[doc_id] = weighted_scores.get(doc_id, 0.0) + weighted_score
+
+        # Process sparse results
+        for i, doc in enumerate(sparse_results):
+            doc_id = doc["id"]
+            # Combine similarity score with rank-based discount
+            rank_score = 1.0 / (i + 1)  # Higher ranks get higher scores
+            weighted_score = sparse_weight * (
+                doc["similarity_score"] * 0.7 + rank_score * 0.3
+            )
+            weighted_scores[doc_id] = weighted_scores.get(doc_id, 0.0) + weighted_score
+
+        # Sort and build results
+        sorted_docs = sorted(weighted_scores.items(), key=lambda x: x[1], reverse=True)[
+            :top_k
+        ]
+
+        # Build result documents
+        doc_map = {}
+        for doc in dense_results + sparse_results:
+            doc_map[doc["id"]] = doc
+
+        results = []
+        for doc_id, weighted_score in sorted_docs:
+            if doc_id in doc_map:
+                doc = doc_map[doc_id].copy()
+                doc["hybrid_score"] = weighted_score
+                doc["fusion_method"] = "weighted"
+                doc["rank"] = len(results) + 1
+                results.append(doc)
+
+        return results
```
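Of the three strategies, RRF is the default and depends only on ranks, so it stays robust when dense cosine scores (0-1) and sparse BM25 scores (unbounded) live on different scales. A rough usage sketch, assuming kailash >= 0.4.1 and inputs shaped the way `_normalize_results` expects; we call `run()` directly for illustration, whereas inside a workflow the Kailash runtime would drive the node:

```python
from kailash.nodes.data.retrieval import HybridRetrieverNode

dense = [
    {"id": "d1", "content": "vector match", "similarity_score": 0.92},
    {"id": "d2", "content": "close match", "similarity_score": 0.85},
]
sparse = [
    {"id": "d2", "content": "close match", "similarity_score": 7.1},  # BM25-style score
    {"id": "d3", "content": "keyword match", "similarity_score": 5.4},
]

node = HybridRetrieverNode(fusion_strategy="rrf", rrf_k=60, top_k=3)
out = node.run(query="example", dense_results=dense, sparse_results=sparse)

# d2 ranks 2nd in the dense list and 1st in the sparse list, so its RRF score
# is 1/(60+2) + 1/(60+1) ≈ 0.0325, ahead of d1's 1/(60+1) ≈ 0.0164 and
# d3's 1/(60+2) ≈ 0.0161 -> expected order: d2, d1, d3.
print([(d["id"], round(d["hybrid_score"], 4)) for d in out["hybrid_results"]])
```

Note that the unbounded BM25 score never touches the RRF result; only its rank does, which is exactly why RRF needs no score normalization.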