aiecs 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiecs/__init__.py +72 -0
- aiecs/__main__.py +41 -0
- aiecs/aiecs_client.py +469 -0
- aiecs/application/__init__.py +10 -0
- aiecs/application/executors/__init__.py +10 -0
- aiecs/application/executors/operation_executor.py +363 -0
- aiecs/application/knowledge_graph/__init__.py +7 -0
- aiecs/application/knowledge_graph/builder/__init__.py +37 -0
- aiecs/application/knowledge_graph/builder/document_builder.py +375 -0
- aiecs/application/knowledge_graph/builder/graph_builder.py +356 -0
- aiecs/application/knowledge_graph/builder/schema_mapping.py +531 -0
- aiecs/application/knowledge_graph/builder/structured_pipeline.py +443 -0
- aiecs/application/knowledge_graph/builder/text_chunker.py +319 -0
- aiecs/application/knowledge_graph/extractors/__init__.py +27 -0
- aiecs/application/knowledge_graph/extractors/base.py +100 -0
- aiecs/application/knowledge_graph/extractors/llm_entity_extractor.py +327 -0
- aiecs/application/knowledge_graph/extractors/llm_relation_extractor.py +349 -0
- aiecs/application/knowledge_graph/extractors/ner_entity_extractor.py +244 -0
- aiecs/application/knowledge_graph/fusion/__init__.py +23 -0
- aiecs/application/knowledge_graph/fusion/entity_deduplicator.py +387 -0
- aiecs/application/knowledge_graph/fusion/entity_linker.py +343 -0
- aiecs/application/knowledge_graph/fusion/knowledge_fusion.py +580 -0
- aiecs/application/knowledge_graph/fusion/relation_deduplicator.py +189 -0
- aiecs/application/knowledge_graph/pattern_matching/__init__.py +21 -0
- aiecs/application/knowledge_graph/pattern_matching/pattern_matcher.py +344 -0
- aiecs/application/knowledge_graph/pattern_matching/query_executor.py +378 -0
- aiecs/application/knowledge_graph/profiling/__init__.py +12 -0
- aiecs/application/knowledge_graph/profiling/query_plan_visualizer.py +199 -0
- aiecs/application/knowledge_graph/profiling/query_profiler.py +223 -0
- aiecs/application/knowledge_graph/reasoning/__init__.py +27 -0
- aiecs/application/knowledge_graph/reasoning/evidence_synthesis.py +347 -0
- aiecs/application/knowledge_graph/reasoning/inference_engine.py +504 -0
- aiecs/application/knowledge_graph/reasoning/logic_form_parser.py +167 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/__init__.py +79 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_builder.py +513 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_nodes.py +630 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_validator.py +654 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/error_handler.py +477 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/parser.py +390 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/query_context.py +217 -0
- aiecs/application/knowledge_graph/reasoning/logic_query_integration.py +169 -0
- aiecs/application/knowledge_graph/reasoning/query_planner.py +872 -0
- aiecs/application/knowledge_graph/reasoning/reasoning_engine.py +554 -0
- aiecs/application/knowledge_graph/retrieval/__init__.py +19 -0
- aiecs/application/knowledge_graph/retrieval/retrieval_strategies.py +596 -0
- aiecs/application/knowledge_graph/search/__init__.py +59 -0
- aiecs/application/knowledge_graph/search/hybrid_search.py +423 -0
- aiecs/application/knowledge_graph/search/reranker.py +295 -0
- aiecs/application/knowledge_graph/search/reranker_strategies.py +553 -0
- aiecs/application/knowledge_graph/search/text_similarity.py +398 -0
- aiecs/application/knowledge_graph/traversal/__init__.py +15 -0
- aiecs/application/knowledge_graph/traversal/enhanced_traversal.py +329 -0
- aiecs/application/knowledge_graph/traversal/path_scorer.py +269 -0
- aiecs/application/knowledge_graph/validators/__init__.py +13 -0
- aiecs/application/knowledge_graph/validators/relation_validator.py +189 -0
- aiecs/application/knowledge_graph/visualization/__init__.py +11 -0
- aiecs/application/knowledge_graph/visualization/graph_visualizer.py +321 -0
- aiecs/common/__init__.py +9 -0
- aiecs/common/knowledge_graph/__init__.py +17 -0
- aiecs/common/knowledge_graph/runnable.py +484 -0
- aiecs/config/__init__.py +16 -0
- aiecs/config/config.py +498 -0
- aiecs/config/graph_config.py +137 -0
- aiecs/config/registry.py +23 -0
- aiecs/core/__init__.py +46 -0
- aiecs/core/interface/__init__.py +34 -0
- aiecs/core/interface/execution_interface.py +152 -0
- aiecs/core/interface/storage_interface.py +171 -0
- aiecs/domain/__init__.py +289 -0
- aiecs/domain/agent/__init__.py +189 -0
- aiecs/domain/agent/base_agent.py +697 -0
- aiecs/domain/agent/exceptions.py +103 -0
- aiecs/domain/agent/graph_aware_mixin.py +559 -0
- aiecs/domain/agent/hybrid_agent.py +490 -0
- aiecs/domain/agent/integration/__init__.py +26 -0
- aiecs/domain/agent/integration/context_compressor.py +222 -0
- aiecs/domain/agent/integration/context_engine_adapter.py +252 -0
- aiecs/domain/agent/integration/retry_policy.py +219 -0
- aiecs/domain/agent/integration/role_config.py +213 -0
- aiecs/domain/agent/knowledge_aware_agent.py +646 -0
- aiecs/domain/agent/lifecycle.py +296 -0
- aiecs/domain/agent/llm_agent.py +300 -0
- aiecs/domain/agent/memory/__init__.py +12 -0
- aiecs/domain/agent/memory/conversation.py +197 -0
- aiecs/domain/agent/migration/__init__.py +14 -0
- aiecs/domain/agent/migration/conversion.py +160 -0
- aiecs/domain/agent/migration/legacy_wrapper.py +90 -0
- aiecs/domain/agent/models.py +317 -0
- aiecs/domain/agent/observability.py +407 -0
- aiecs/domain/agent/persistence.py +289 -0
- aiecs/domain/agent/prompts/__init__.py +29 -0
- aiecs/domain/agent/prompts/builder.py +161 -0
- aiecs/domain/agent/prompts/formatters.py +189 -0
- aiecs/domain/agent/prompts/template.py +255 -0
- aiecs/domain/agent/registry.py +260 -0
- aiecs/domain/agent/tool_agent.py +257 -0
- aiecs/domain/agent/tools/__init__.py +12 -0
- aiecs/domain/agent/tools/schema_generator.py +221 -0
- aiecs/domain/community/__init__.py +155 -0
- aiecs/domain/community/agent_adapter.py +477 -0
- aiecs/domain/community/analytics.py +481 -0
- aiecs/domain/community/collaborative_workflow.py +642 -0
- aiecs/domain/community/communication_hub.py +645 -0
- aiecs/domain/community/community_builder.py +320 -0
- aiecs/domain/community/community_integration.py +800 -0
- aiecs/domain/community/community_manager.py +813 -0
- aiecs/domain/community/decision_engine.py +879 -0
- aiecs/domain/community/exceptions.py +225 -0
- aiecs/domain/community/models/__init__.py +33 -0
- aiecs/domain/community/models/community_models.py +268 -0
- aiecs/domain/community/resource_manager.py +457 -0
- aiecs/domain/community/shared_context_manager.py +603 -0
- aiecs/domain/context/__init__.py +58 -0
- aiecs/domain/context/context_engine.py +989 -0
- aiecs/domain/context/conversation_models.py +354 -0
- aiecs/domain/context/graph_memory.py +467 -0
- aiecs/domain/execution/__init__.py +12 -0
- aiecs/domain/execution/model.py +57 -0
- aiecs/domain/knowledge_graph/__init__.py +19 -0
- aiecs/domain/knowledge_graph/models/__init__.py +52 -0
- aiecs/domain/knowledge_graph/models/entity.py +130 -0
- aiecs/domain/knowledge_graph/models/evidence.py +194 -0
- aiecs/domain/knowledge_graph/models/inference_rule.py +186 -0
- aiecs/domain/knowledge_graph/models/path.py +179 -0
- aiecs/domain/knowledge_graph/models/path_pattern.py +173 -0
- aiecs/domain/knowledge_graph/models/query.py +272 -0
- aiecs/domain/knowledge_graph/models/query_plan.py +187 -0
- aiecs/domain/knowledge_graph/models/relation.py +136 -0
- aiecs/domain/knowledge_graph/schema/__init__.py +23 -0
- aiecs/domain/knowledge_graph/schema/entity_type.py +135 -0
- aiecs/domain/knowledge_graph/schema/graph_schema.py +271 -0
- aiecs/domain/knowledge_graph/schema/property_schema.py +155 -0
- aiecs/domain/knowledge_graph/schema/relation_type.py +171 -0
- aiecs/domain/knowledge_graph/schema/schema_manager.py +496 -0
- aiecs/domain/knowledge_graph/schema/type_enums.py +205 -0
- aiecs/domain/task/__init__.py +13 -0
- aiecs/domain/task/dsl_processor.py +613 -0
- aiecs/domain/task/model.py +62 -0
- aiecs/domain/task/task_context.py +268 -0
- aiecs/infrastructure/__init__.py +24 -0
- aiecs/infrastructure/graph_storage/__init__.py +11 -0
- aiecs/infrastructure/graph_storage/base.py +601 -0
- aiecs/infrastructure/graph_storage/batch_operations.py +449 -0
- aiecs/infrastructure/graph_storage/cache.py +429 -0
- aiecs/infrastructure/graph_storage/distributed.py +226 -0
- aiecs/infrastructure/graph_storage/error_handling.py +390 -0
- aiecs/infrastructure/graph_storage/graceful_degradation.py +306 -0
- aiecs/infrastructure/graph_storage/health_checks.py +378 -0
- aiecs/infrastructure/graph_storage/in_memory.py +514 -0
- aiecs/infrastructure/graph_storage/index_optimization.py +483 -0
- aiecs/infrastructure/graph_storage/lazy_loading.py +410 -0
- aiecs/infrastructure/graph_storage/metrics.py +357 -0
- aiecs/infrastructure/graph_storage/migration.py +413 -0
- aiecs/infrastructure/graph_storage/pagination.py +471 -0
- aiecs/infrastructure/graph_storage/performance_monitoring.py +466 -0
- aiecs/infrastructure/graph_storage/postgres.py +871 -0
- aiecs/infrastructure/graph_storage/query_optimizer.py +635 -0
- aiecs/infrastructure/graph_storage/schema_cache.py +290 -0
- aiecs/infrastructure/graph_storage/sqlite.py +623 -0
- aiecs/infrastructure/graph_storage/streaming.py +495 -0
- aiecs/infrastructure/messaging/__init__.py +13 -0
- aiecs/infrastructure/messaging/celery_task_manager.py +383 -0
- aiecs/infrastructure/messaging/websocket_manager.py +298 -0
- aiecs/infrastructure/monitoring/__init__.py +34 -0
- aiecs/infrastructure/monitoring/executor_metrics.py +174 -0
- aiecs/infrastructure/monitoring/global_metrics_manager.py +213 -0
- aiecs/infrastructure/monitoring/structured_logger.py +48 -0
- aiecs/infrastructure/monitoring/tracing_manager.py +410 -0
- aiecs/infrastructure/persistence/__init__.py +24 -0
- aiecs/infrastructure/persistence/context_engine_client.py +187 -0
- aiecs/infrastructure/persistence/database_manager.py +333 -0
- aiecs/infrastructure/persistence/file_storage.py +754 -0
- aiecs/infrastructure/persistence/redis_client.py +220 -0
- aiecs/llm/__init__.py +86 -0
- aiecs/llm/callbacks/__init__.py +11 -0
- aiecs/llm/callbacks/custom_callbacks.py +264 -0
- aiecs/llm/client_factory.py +420 -0
- aiecs/llm/clients/__init__.py +33 -0
- aiecs/llm/clients/base_client.py +193 -0
- aiecs/llm/clients/googleai_client.py +181 -0
- aiecs/llm/clients/openai_client.py +131 -0
- aiecs/llm/clients/vertex_client.py +437 -0
- aiecs/llm/clients/xai_client.py +184 -0
- aiecs/llm/config/__init__.py +51 -0
- aiecs/llm/config/config_loader.py +275 -0
- aiecs/llm/config/config_validator.py +236 -0
- aiecs/llm/config/model_config.py +151 -0
- aiecs/llm/utils/__init__.py +10 -0
- aiecs/llm/utils/validate_config.py +91 -0
- aiecs/main.py +363 -0
- aiecs/scripts/__init__.py +3 -0
- aiecs/scripts/aid/VERSION_MANAGEMENT.md +97 -0
- aiecs/scripts/aid/__init__.py +19 -0
- aiecs/scripts/aid/version_manager.py +215 -0
- aiecs/scripts/dependance_check/DEPENDENCY_SYSTEM_SUMMARY.md +242 -0
- aiecs/scripts/dependance_check/README_DEPENDENCY_CHECKER.md +310 -0
- aiecs/scripts/dependance_check/__init__.py +17 -0
- aiecs/scripts/dependance_check/dependency_checker.py +938 -0
- aiecs/scripts/dependance_check/dependency_fixer.py +391 -0
- aiecs/scripts/dependance_check/download_nlp_data.py +396 -0
- aiecs/scripts/dependance_check/quick_dependency_check.py +270 -0
- aiecs/scripts/dependance_check/setup_nlp_data.sh +217 -0
- aiecs/scripts/dependance_patch/__init__.py +7 -0
- aiecs/scripts/dependance_patch/fix_weasel/README_WEASEL_PATCH.md +126 -0
- aiecs/scripts/dependance_patch/fix_weasel/__init__.py +11 -0
- aiecs/scripts/dependance_patch/fix_weasel/fix_weasel_validator.py +128 -0
- aiecs/scripts/dependance_patch/fix_weasel/fix_weasel_validator.sh +82 -0
- aiecs/scripts/dependance_patch/fix_weasel/patch_weasel_library.sh +188 -0
- aiecs/scripts/dependance_patch/fix_weasel/run_weasel_patch.sh +41 -0
- aiecs/scripts/tools_develop/README.md +449 -0
- aiecs/scripts/tools_develop/TOOL_AUTO_DISCOVERY.md +234 -0
- aiecs/scripts/tools_develop/__init__.py +21 -0
- aiecs/scripts/tools_develop/check_type_annotations.py +259 -0
- aiecs/scripts/tools_develop/validate_tool_schemas.py +422 -0
- aiecs/scripts/tools_develop/verify_tools.py +356 -0
- aiecs/tasks/__init__.py +1 -0
- aiecs/tasks/worker.py +172 -0
- aiecs/tools/__init__.py +299 -0
- aiecs/tools/apisource/__init__.py +99 -0
- aiecs/tools/apisource/intelligence/__init__.py +19 -0
- aiecs/tools/apisource/intelligence/data_fusion.py +381 -0
- aiecs/tools/apisource/intelligence/query_analyzer.py +413 -0
- aiecs/tools/apisource/intelligence/search_enhancer.py +388 -0
- aiecs/tools/apisource/monitoring/__init__.py +9 -0
- aiecs/tools/apisource/monitoring/metrics.py +303 -0
- aiecs/tools/apisource/providers/__init__.py +115 -0
- aiecs/tools/apisource/providers/base.py +664 -0
- aiecs/tools/apisource/providers/census.py +401 -0
- aiecs/tools/apisource/providers/fred.py +564 -0
- aiecs/tools/apisource/providers/newsapi.py +412 -0
- aiecs/tools/apisource/providers/worldbank.py +357 -0
- aiecs/tools/apisource/reliability/__init__.py +12 -0
- aiecs/tools/apisource/reliability/error_handler.py +375 -0
- aiecs/tools/apisource/reliability/fallback_strategy.py +391 -0
- aiecs/tools/apisource/tool.py +850 -0
- aiecs/tools/apisource/utils/__init__.py +9 -0
- aiecs/tools/apisource/utils/validators.py +338 -0
- aiecs/tools/base_tool.py +201 -0
- aiecs/tools/docs/__init__.py +121 -0
- aiecs/tools/docs/ai_document_orchestrator.py +599 -0
- aiecs/tools/docs/ai_document_writer_orchestrator.py +2403 -0
- aiecs/tools/docs/content_insertion_tool.py +1333 -0
- aiecs/tools/docs/document_creator_tool.py +1317 -0
- aiecs/tools/docs/document_layout_tool.py +1166 -0
- aiecs/tools/docs/document_parser_tool.py +994 -0
- aiecs/tools/docs/document_writer_tool.py +1818 -0
- aiecs/tools/knowledge_graph/__init__.py +17 -0
- aiecs/tools/knowledge_graph/graph_reasoning_tool.py +734 -0
- aiecs/tools/knowledge_graph/graph_search_tool.py +923 -0
- aiecs/tools/knowledge_graph/kg_builder_tool.py +476 -0
- aiecs/tools/langchain_adapter.py +542 -0
- aiecs/tools/schema_generator.py +275 -0
- aiecs/tools/search_tool/__init__.py +100 -0
- aiecs/tools/search_tool/analyzers.py +589 -0
- aiecs/tools/search_tool/cache.py +260 -0
- aiecs/tools/search_tool/constants.py +128 -0
- aiecs/tools/search_tool/context.py +216 -0
- aiecs/tools/search_tool/core.py +749 -0
- aiecs/tools/search_tool/deduplicator.py +123 -0
- aiecs/tools/search_tool/error_handler.py +271 -0
- aiecs/tools/search_tool/metrics.py +371 -0
- aiecs/tools/search_tool/rate_limiter.py +178 -0
- aiecs/tools/search_tool/schemas.py +277 -0
- aiecs/tools/statistics/__init__.py +80 -0
- aiecs/tools/statistics/ai_data_analysis_orchestrator.py +643 -0
- aiecs/tools/statistics/ai_insight_generator_tool.py +505 -0
- aiecs/tools/statistics/ai_report_orchestrator_tool.py +694 -0
- aiecs/tools/statistics/data_loader_tool.py +564 -0
- aiecs/tools/statistics/data_profiler_tool.py +658 -0
- aiecs/tools/statistics/data_transformer_tool.py +573 -0
- aiecs/tools/statistics/data_visualizer_tool.py +495 -0
- aiecs/tools/statistics/model_trainer_tool.py +487 -0
- aiecs/tools/statistics/statistical_analyzer_tool.py +459 -0
- aiecs/tools/task_tools/__init__.py +86 -0
- aiecs/tools/task_tools/chart_tool.py +732 -0
- aiecs/tools/task_tools/classfire_tool.py +922 -0
- aiecs/tools/task_tools/image_tool.py +447 -0
- aiecs/tools/task_tools/office_tool.py +684 -0
- aiecs/tools/task_tools/pandas_tool.py +635 -0
- aiecs/tools/task_tools/report_tool.py +635 -0
- aiecs/tools/task_tools/research_tool.py +392 -0
- aiecs/tools/task_tools/scraper_tool.py +715 -0
- aiecs/tools/task_tools/stats_tool.py +688 -0
- aiecs/tools/temp_file_manager.py +130 -0
- aiecs/tools/tool_executor/__init__.py +37 -0
- aiecs/tools/tool_executor/tool_executor.py +881 -0
- aiecs/utils/LLM_output_structor.py +445 -0
- aiecs/utils/__init__.py +34 -0
- aiecs/utils/base_callback.py +47 -0
- aiecs/utils/cache_provider.py +695 -0
- aiecs/utils/execution_utils.py +184 -0
- aiecs/utils/logging.py +1 -0
- aiecs/utils/prompt_loader.py +14 -0
- aiecs/utils/token_usage_repository.py +323 -0
- aiecs/ws/__init__.py +0 -0
- aiecs/ws/socket_server.py +52 -0
- aiecs-1.5.1.dist-info/METADATA +608 -0
- aiecs-1.5.1.dist-info/RECORD +302 -0
- aiecs-1.5.1.dist-info/WHEEL +5 -0
- aiecs-1.5.1.dist-info/entry_points.txt +10 -0
- aiecs-1.5.1.dist-info/licenses/LICENSE +225 -0
- aiecs-1.5.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,423 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Hybrid Search Strategy
|
|
3
|
+
|
|
4
|
+
Combines vector similarity search with graph structure traversal
|
|
5
|
+
to provide enhanced search results.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import List, Optional, Dict, Tuple
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from pydantic import BaseModel, Field
|
|
11
|
+
from aiecs.domain.knowledge_graph.models.entity import Entity
|
|
12
|
+
from aiecs.domain.knowledge_graph.models.path import Path
|
|
13
|
+
from aiecs.infrastructure.graph_storage.base import GraphStore
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SearchMode(str, Enum):
    """
    Selects which retrieval technique a hybrid search runs.

    Members are also ``str`` instances, so a member compares equal to its
    plain string value (e.g. ``SearchMode.HYBRID == "hybrid"``).
    """

    # Rank purely by embedding similarity.
    VECTOR_ONLY = "vector_only"
    # Rank purely by graph traversal from seed entities.
    GRAPH_ONLY = "graph_only"
    # Blend vector and graph scores with configurable weights.
    HYBRID = "hybrid"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class HybridSearchConfig(BaseModel):
    """
    Tunable parameters for a hybrid search request

    Attributes:
        mode: Search mode (vector_only, graph_only, hybrid)
        vector_weight: Weight for vector similarity scores (0.0-1.0)
        graph_weight: Weight for graph structure scores (0.0-1.0)
        max_results: Maximum number of results to return (at least 1)
        vector_threshold: Minimum similarity threshold for vector search (0.0-1.0)
        max_graph_depth: Maximum depth for graph traversal (1-5)
        expand_results: Whether to expand vector results with graph neighbors
        min_combined_score: Minimum combined score threshold (0.0-1.0)
        entity_type_filter: Optional entity type; when set, only entities of
            this type are returned
    """

    mode: SearchMode = Field(default=SearchMode.HYBRID, description="Search mode")

    # The two weights need not sum to 1.0.
    vector_weight: float = Field(
        default=0.6, ge=0.0, le=1.0, description="Weight for vector similarity scores"
    )
    graph_weight: float = Field(
        default=0.4, ge=0.0, le=1.0, description="Weight for graph structure scores"
    )

    max_results: int = Field(default=10, ge=1, description="Maximum number of results")

    vector_threshold: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Minimum similarity threshold for vector search",
    )

    max_graph_depth: int = Field(
        default=2, ge=1, le=5, description="Maximum depth for graph traversal"
    )

    expand_results: bool = Field(
        default=True,
        description="Whether to expand vector results with graph neighbors",
    )

    min_combined_score: float = Field(
        default=0.0, ge=0.0, le=1.0, description="Minimum combined score threshold"
    )

    entity_type_filter: Optional[str] = Field(
        default=None, description="Optional entity type filter"
    )

    class Config:
        # NOTE(review): per pydantic's model-config docs this stores enum
        # fields as their underlying values rather than enum members.
        use_enum_values = True
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class HybridSearchStrategy:
    """
    Hybrid Search Strategy

    Combines vector similarity search with graph structure traversal
    to provide enhanced search results that leverage both semantic
    similarity and structural relationships.

    Search Modes:
    - VECTOR_ONLY: Pure vector similarity search
    - GRAPH_ONLY: Pure graph traversal from seed entities
    - HYBRID: Combines both approaches with weighted scoring

    Example:
        ```python
        strategy = HybridSearchStrategy(graph_store)

        config = HybridSearchConfig(
            mode=SearchMode.HYBRID,
            vector_weight=0.6,
            graph_weight=0.4,
            max_results=10,
            expand_results=True
        )

        results = await strategy.search(
            query_embedding=[0.1, 0.2, ...],
            config=config
        )

        for entity, score in results:
            print(f"{entity.id}: {score:.3f}")
        ```
    """

    # Number of top vector hits used as traversal seeds when the caller
    # does not supply any seed entity IDs.
    _DEFAULT_SEED_COUNT = 5

    def __init__(self, graph_store: GraphStore):
        """
        Initialize hybrid search strategy

        Args:
            graph_store: Graph storage backend
        """
        self.graph_store = graph_store

    async def search(
        self,
        query_embedding: List[float],
        config: Optional[HybridSearchConfig] = None,
        seed_entity_ids: Optional[List[str]] = None,
    ) -> List[Tuple[Entity, float]]:
        """
        Perform hybrid search

        Dispatches to the vector, graph or hybrid implementation according
        to ``config.mode``.

        Args:
            query_embedding: Query vector embedding
            config: Search configuration (uses defaults if None)
            seed_entity_ids: Optional seed entities for graph traversal

        Returns:
            List of (entity, score) tuples sorted by score descending
        """
        if config is None:
            config = HybridSearchConfig()

        if config.mode == SearchMode.VECTOR_ONLY:
            return await self._vector_search(query_embedding, config)

        if config.mode == SearchMode.GRAPH_ONLY:
            if not seed_entity_ids:
                # No seeds provided: bootstrap them from the top vector hits.
                vector_results = await self._vector_search(
                    query_embedding, config, max_results=self._DEFAULT_SEED_COUNT
                )
                seed_entity_ids = [entity.id for entity, _ in vector_results]
            return await self._graph_search(seed_entity_ids, config)

        # Any other mode is treated as HYBRID.
        return await self._hybrid_search(query_embedding, config, seed_entity_ids)

    async def _vector_search(
        self,
        query_embedding: List[float],
        config: HybridSearchConfig,
        max_results: Optional[int] = None,
    ) -> List[Tuple[Entity, float]]:
        """
        Perform vector similarity search

        Args:
            query_embedding: Query vector
            config: Search configuration
            max_results: Optional override for max results; a falsy value
                (None or 0) falls back to ``config.max_results``

        Returns:
            List of (entity, score) tuples as returned by the graph store
        """
        return await self.graph_store.vector_search(
            query_embedding=query_embedding,
            entity_type=config.entity_type_filter,
            max_results=max_results or config.max_results,
            score_threshold=config.vector_threshold,
        )

    async def _graph_search(
        self, seed_entity_ids: List[str], config: HybridSearchConfig
    ) -> List[Tuple[Entity, float]]:
        """
        Perform graph structure search from seed entities

        Walks outgoing edges breadth-first from every seed. An entity first
        reached at depth ``d`` (seeds are depth 0) scores ``1 / (d + 1)``;
        when several seeds reach the same entity the highest score wins.
        Because seeds occupy depth 0, entities up to
        ``config.max_graph_depth - 1`` hops away are scored.

        Args:
            seed_entity_ids: Starting entities for traversal
            config: Search configuration

        Returns:
            List of (entity, score) tuples sorted by score descending and
            truncated to ``config.max_results``
        """
        entity_scores: Dict[str, float] = {}
        last_depth = config.max_graph_depth - 1

        for seed_id in seed_entity_ids:
            frontier = {seed_id}
            # Visited set is per seed so each seed explores independently.
            visited = set()

            for depth in range(config.max_graph_depth):
                # Score decreases with depth
                depth_score = 1.0 / (depth + 1)
                next_frontier = set()

                for entity_id in frontier:
                    if entity_id in visited:
                        continue
                    visited.add(entity_id)

                    # Keep the best score if the entity is reachable via
                    # multiple paths or seeds.
                    entity_scores[entity_id] = max(
                        entity_scores.get(entity_id, 0.0), depth_score
                    )

                    # The frontier built at the last level is never
                    # scored, so fetching neighbours there would be a
                    # wasted store round-trip.
                    if depth == last_depth:
                        continue

                    neighbors = await self.graph_store.get_neighbors(
                        entity_id, direction="outgoing"
                    )
                    next_frontier.update(
                        neighbor.id for neighbor in neighbors if neighbor.id not in visited
                    )

                frontier = next_frontier
                if not frontier:
                    break

        # Resolve IDs to entities, dropping missing ones and those that do
        # not match the optional type filter.
        type_filter = config.entity_type_filter
        results: List[Tuple[Entity, float]] = []
        for entity_id, score in entity_scores.items():
            entity = await self.graph_store.get_entity(entity_id)
            if not entity:
                continue
            if type_filter and entity.entity_type != type_filter:
                continue
            results.append((entity, score))

        # Sort by score descending
        results.sort(key=lambda item: item[1], reverse=True)

        # Return top results
        return results[: config.max_results]

    async def _hybrid_search(
        self,
        query_embedding: List[float],
        config: HybridSearchConfig,
        seed_entity_ids: Optional[List[str]] = None,
    ) -> List[Tuple[Entity, float]]:
        """
        Perform hybrid search combining vector and graph

        Args:
            query_embedding: Query vector
            config: Search configuration
            seed_entity_ids: Optional seed entities; when omitted the top
                vector hits are used as seeds

        Returns:
            List of (entity, score) tuples with combined scores, sorted by
            score descending and truncated to ``config.max_results``
        """
        # Step 1: Vector search (over-fetch so expansion has candidates)
        vector_results = await self._vector_search(
            query_embedding,
            config,
            max_results=config.max_results * 2,
        )
        vector_scores: Dict[str, float] = {
            entity.id: score for entity, score in vector_results
        }

        # Step 2: Graph expansion (if enabled)
        graph_scores: Dict[str, float] = {}
        if config.expand_results:
            seeds = seed_entity_ids or [
                entity.id for entity, _ in vector_results[: self._DEFAULT_SEED_COUNT]
            ]
            graph_results = await self._graph_search(seeds, config)
            graph_scores = {entity.id: score for entity, score in graph_results}

        # Step 3: Combine scores
        combined_scores = await self._combine_scores(vector_scores, graph_scores, config)

        # Step 4: Retrieve entities and create results
        results: List[Tuple[Entity, float]] = []
        for entity_id, combined_score in combined_scores.items():
            # Apply minimum score threshold
            if combined_score < config.min_combined_score:
                continue

            entity = await self.graph_store.get_entity(entity_id)
            if entity:
                results.append((entity, combined_score))

        # Sort by combined score descending
        results.sort(key=lambda item: item[1], reverse=True)

        # Return top results
        return results[: config.max_results]

    async def _combine_scores(
        self,
        vector_scores: Dict[str, float],
        graph_scores: Dict[str, float],
        config: HybridSearchConfig,
    ) -> Dict[str, float]:
        """
        Combine vector and graph scores with weighted averaging

        Weights are normalized to sum to 1.0 before use. An entity missing
        from one of the inputs contributes 0.0 for that component.

        Args:
            vector_scores: Entity ID to vector similarity score
            graph_scores: Entity ID to graph structure score
            config: Search configuration

        Returns:
            Combined scores dictionary covering the union of both inputs
        """
        # Normalize weights; guard against division by zero when both
        # weights are 0 (every combined score is then 0.0).
        total_weight = config.vector_weight + config.graph_weight
        if total_weight == 0:
            total_weight = 1.0

        norm_vector_weight = config.vector_weight / total_weight
        norm_graph_weight = config.graph_weight / total_weight

        all_entity_ids = set(vector_scores.keys()) | set(graph_scores.keys())

        return {
            entity_id: (
                vector_scores.get(entity_id, 0.0) * norm_vector_weight
                + graph_scores.get(entity_id, 0.0) * norm_graph_weight
            )
            for entity_id in all_entity_ids
        }

    async def search_with_expansion(
        self,
        query_embedding: List[float],
        config: Optional[HybridSearchConfig] = None,
        include_paths: bool = False,
    ) -> Tuple[List[Tuple[Entity, float]], Optional[List[Path]]]:
        """
        Search with result expansion and optional path tracking

        Args:
            query_embedding: Query vector
            config: Search configuration (uses defaults if None)
            include_paths: Whether to include paths to results

        Returns:
            Tuple of (results, paths). ``paths`` is None unless both
            ``include_paths`` and ``config.expand_results`` are true.
        """
        if config is None:
            config = HybridSearchConfig()

        # Perform search
        results = await self.search(query_embedding, config)

        paths = None
        if include_paths and config.expand_results:
            # Find paths connecting the top-ranked results
            paths = await self._find_result_paths(results, config)

        return results, paths

    async def _find_result_paths(
        self, results: List[Tuple[Entity, float]], config: HybridSearchConfig
    ) -> List[Path]:
        """
        Find paths between top results

        Each of the first three results is used as a source and paired with
        up to three results ranked immediately after it; at most two paths
        are requested per pair.

        Args:
            results: Search results
            config: Search configuration

        Returns:
            List of paths connecting results (empty when fewer than two
            results are given)
        """
        if len(results) < 2:
            return []

        paths: List[Path] = []

        for index, (source, _) in enumerate(results[:3]):
            for target, _ in results[index + 1 : index + 4]:
                # Find paths between these entities
                found_paths = await self.graph_store.find_paths(
                    source_entity_id=source.id,
                    target_entity_id=target.id,
                    max_depth=config.max_graph_depth,
                    max_paths=2,
                )
                paths.extend(found_paths)

        return paths
|
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Result Reranking Framework
|
|
3
|
+
|
|
4
|
+
Pluggable reranking strategies for improving search result relevance.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from typing import List, Dict, Optional, Tuple
|
|
9
|
+
from enum import Enum
|
|
10
|
+
|
|
11
|
+
from aiecs.domain.knowledge_graph.models.entity import Entity
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ScoreCombinationMethod(str, Enum):
    """
    Available ways of merging the scores produced by several strategies

    Members are also ``str`` instances, so each one compares equal to its
    plain string value (for example ``ScoreCombinationMethod.RRF == "rrf"``).
    """

    # Weighted mean of the per-strategy scores
    WEIGHTED_AVERAGE = "weighted_average"
    # Reciprocal Rank Fusion
    RRF = "rrf"
    # Highest per-strategy score
    MAX = "max"
    # Lowest per-strategy score
    MIN = "min"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class RerankerStrategy(ABC):
    """
    Interface that every reranking strategy implements

    A strategy assigns a relevance score to each candidate entity for a
    given query. Several strategies can run side by side and have their
    scores merged with one of the available combination methods.

    Example:
        ```python
        class TextSimilarityReranker(RerankerStrategy):
            @property
            def name(self) -> str:
                return "text"

            async def score(
                self,
                query: str,
                entities: List[Entity],
                **kwargs
            ) -> List[float]:
                # Compute BM25 scores
                return scores
        ```
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Name that identifies this strategy"""

    @abstractmethod
    async def score(self, query: str, entities: List[Entity], **kwargs) -> List[float]:
        """
        Score each entity for relevance to the query

        Args:
            query: Query text or context
            entities: Entities to score
            **kwargs: Extra parameters specific to the strategy

        Returns:
            One score per entity, in the same order as ``entities``.
            Scores work best when they fall within [0.0, 1.0].
        """
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def normalize_scores(scores: List[float], method: str = "min_max") -> List[float]:
    """
    Normalize scores to [0.0, 1.0] range

    Args:
        scores: Raw scores to normalize
        method: Normalization method:
            - "min_max": linear rescale, minimum -> 0.0 and maximum -> 1.0
            - "z_score": standardize with the sample standard deviation,
              then squash through a logistic sigmoid
            - "softmax": exponentiate and divide by the sum, so the
              outputs add up to 1.0

    Returns:
        Normalized scores in [0.0, 1.0] range, in the same order as the
        input. Empty input yields an empty list. For "min_max" and
        "z_score", inputs without any spread (all values equal, or fewer
        than two values for "z_score") yield 1.0 for every entry.

    Raises:
        ValueError: If ``method`` is unknown (only checked when ``scores``
            is non-empty)
    """
    import math

    if not scores:
        return []

    if method == "min_max":
        min_score = min(scores)
        max_score = max(scores)
        if max_score == min_score:
            # No spread: avoid dividing by zero
            return [1.0] * len(scores)
        return [(s - min_score) / (max_score - min_score) for s in scores]

    if method == "z_score":
        # Imported lazily so the other methods do not pay for it
        import statistics

        if len(scores) < 2:
            # Sample standard deviation is undefined for a single value
            return [1.0] * len(scores)
        mean = statistics.mean(scores)
        stdev = statistics.stdev(scores)
        if stdev == 0:
            return [1.0] * len(scores)

        def _sigmoid(z: float) -> float:
            # Only ever exponentiate a non-positive value, so math.exp
            # cannot raise OverflowError for strongly negative z-scores
            if z >= 0:
                return 1 / (1 + math.exp(-z))
            exp_z = math.exp(z)
            return exp_z / (1 + exp_z)

        return [_sigmoid((s - mean) / stdev) for s in scores]

    if method == "softmax":
        # Shift by the maximum to avoid overflow. The largest score then
        # contributes exp(0) == 1, so the sum below is never zero.
        max_score = max(scores)
        exp_scores = [math.exp(s - max_score) for s in scores]
        sum_exp = sum(exp_scores)
        return [e / sum_exp for e in exp_scores]

    raise ValueError(f"Unknown normalization method: {method}")
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def combine_scores(
    score_dicts: List[Dict[str, float]],
    method: ScoreCombinationMethod = ScoreCombinationMethod.WEIGHTED_AVERAGE,
    weights: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """
    Combine scores from multiple strategies

    Args:
        score_dicts: List of {entity_id: score} dictionaries, one per strategy
        method: Combination method
        weights: Optional weights for weighted_average, keyed by position
            as "strategy_0", "strategy_1", ... matching the order of
            ``score_dicts``. Strategies without an entry get an equal
            share (1 / number of strategies).

    Returns:
        Combined scores as {entity_id: combined_score}, covering every
        entity ID that appears in at least one input dictionary. Keys are
        in first-seen order. Empty when ``score_dicts`` is empty.

    Raises:
        ValueError: If ``method`` is not a known combination method
    """
    if not score_dicts:
        return {}

    # Collect all entity IDs, keeping first-seen order so that the output
    # ordering is deterministic
    all_entity_ids: Dict[str, None] = {}
    for score_dict in score_dicts:
        all_entity_ids.update(dict.fromkeys(score_dict))

    if method == ScoreCombinationMethod.WEIGHTED_AVERAGE:
        default_weight = 1.0 / len(score_dicts)
        if weights is None:
            weights = {}

        # Resolve each strategy's weight once instead of once per entity
        strategy_weights = []
        total_weight = 0.0
        for i in range(len(score_dicts)):
            weight = weights.get(f"strategy_{i}", default_weight)
            strategy_weights.append(weight)
            total_weight += weight

        combined = {}
        for entity_id in all_entity_ids:
            weighted_sum = 0.0
            for weight, score_dict in zip(strategy_weights, score_dicts):
                # An entity a strategy did not score counts as 0.0
                weighted_sum += weight * score_dict.get(entity_id, 0.0)
            combined[entity_id] = weighted_sum / total_weight if total_weight > 0 else 0.0
        return combined

    elif method == ScoreCombinationMethod.RRF:
        # Reciprocal Rank Fusion
        k = 60  # RRF constant

        # Rank every distinct score once per strategy (1-indexed, higher
        # score = lower rank). Tied scores share the best rank of the tie.
        rank_tables = []
        for score_dict in score_dicts:
            ranks = {}
            ordered = sorted(score_dict.values(), reverse=True)
            for position, value in enumerate(ordered, start=1):
                ranks.setdefault(value, position)
            rank_tables.append(ranks)

        combined = {}
        for entity_id in all_entity_ids:
            rrf_score = 0.0
            for score_dict, ranks in zip(score_dicts, rank_tables):
                # Strategies that did not score the entity contribute nothing
                if entity_id in score_dict:
                    rrf_score += 1.0 / (k + ranks[score_dict[entity_id]])
            combined[entity_id] = rrf_score
        return combined

    elif method == ScoreCombinationMethod.MAX:
        # A missing score counts as 0.0
        return {
            entity_id: max(score_dict.get(entity_id, 0.0) for score_dict in score_dicts)
            for entity_id in all_entity_ids
        }

    elif method == ScoreCombinationMethod.MIN:
        # A missing score counts as 1.0
        return {
            entity_id: min(score_dict.get(entity_id, 1.0) for score_dict in score_dicts)
            for entity_id in all_entity_ids
        }

    else:
        raise ValueError(f"Unknown combination method: {method}")
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class ResultReranker:
    """
    Result Reranker orchestrator

    Runs every configured strategy over the candidate entities, optionally
    normalizes each strategy's scores, combines them into a single score
    per entity and returns the entities ordered by that score.

    Example:
        ```python
        # Create strategies
        text_reranker = TextSimilarityReranker()
        semantic_reranker = SemanticReranker()

        # Create reranker (weights are keyed by each strategy's ``name``)
        reranker = ResultReranker(
            strategies=[text_reranker, semantic_reranker],
            combination_method=ScoreCombinationMethod.WEIGHTED_AVERAGE,
            weights={"text": 0.6, "semantic": 0.4}
        )

        # Rerank results
        reranked = await reranker.rerank(
            query="machine learning",
            entities=search_results
        )
        ```
    """

    def __init__(
        self,
        strategies: List[RerankerStrategy],
        combination_method: ScoreCombinationMethod = ScoreCombinationMethod.WEIGHTED_AVERAGE,
        weights: Optional[Dict[str, float]] = None,
        normalize_scores: bool = True,
        normalization_method: str = "min_max",
    ):
        """
        Initialize ResultReranker

        Args:
            strategies: List of reranking strategies
            combination_method: Method for combining scores
            weights: Optional weights for strategies (for weighted_average),
                keyed by strategy ``name``. Positional keys ("strategy_0",
                "strategy_1", ...) are accepted as well.
            normalize_scores: Whether to normalize scores before combining
            normalization_method: Normalization method ("min_max", "z_score", "softmax")

        Raises:
            ValueError: If ``strategies`` is empty
        """
        if not strategies:
            raise ValueError("At least one strategy is required")

        self.strategies = strategies
        self.combination_method = combination_method
        self.weights = weights or {}
        self.normalize_scores = normalize_scores
        self.normalization_method = normalization_method

    def _positional_weights(self) -> Dict[str, float]:
        """Translate name-keyed weights into the positional keys used when combining"""
        # combine_scores looks weights up as "strategy_<index>", so weights
        # given by strategy name would otherwise be silently ignored.
        positional = dict(self.weights)
        for index, strategy in enumerate(self.strategies):
            strategy_name = strategy.name
            if strategy_name in self.weights:
                positional[f"strategy_{index}"] = self.weights[strategy_name]
        return positional

    async def rerank(
        self,
        query: str,
        entities: List[Entity],
        top_k: Optional[int] = None,
        **kwargs,
    ) -> List[Tuple[Entity, float]]:
        """
        Rerank entities using all strategies

        Args:
            query: Query text or context
            entities: List of entities to rerank
            top_k: Optional limit on number of results
            **kwargs: Additional parameters passed to strategies

        Returns:
            List of (entity, combined_score) tuples, sorted by score descending.
            Empty when ``entities`` is empty.
        """
        if not entities:
            return []

        # Get scores from each strategy
        strategy_scores = []
        for strategy in self.strategies:
            scores = await strategy.score(query, entities, **kwargs)

            # Normalize if requested (self.normalize_scores is the flag,
            # normalize_scores is the module-level function)
            if self.normalize_scores:
                scores = normalize_scores(scores, self.normalization_method)

            # Convert to entity_id -> score dictionary
            score_dict = {entity.id: score for entity, score in zip(entities, scores)}
            strategy_scores.append(score_dict)

        # Combine scores
        combined_scores = combine_scores(
            strategy_scores,
            method=self.combination_method,
            weights=self._positional_weights(),
        )

        # Create (entity, score) tuples
        reranked = [(entity, combined_scores.get(entity.id, 0.0)) for entity in entities]

        # Sort by score descending
        reranked.sort(key=lambda x: x[1], reverse=True)

        # Apply top_k limit
        if top_k is not None:
            reranked = reranked[:top_k]

        return reranked
|