PraisonAI 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- praisonai/__init__.py +54 -0
- praisonai/__main__.py +15 -0
- praisonai/acp/__init__.py +54 -0
- praisonai/acp/config.py +159 -0
- praisonai/acp/server.py +587 -0
- praisonai/acp/session.py +219 -0
- praisonai/adapters/__init__.py +50 -0
- praisonai/adapters/readers.py +395 -0
- praisonai/adapters/rerankers.py +315 -0
- praisonai/adapters/retrievers.py +394 -0
- praisonai/adapters/vector_stores.py +409 -0
- praisonai/agent_scheduler.py +337 -0
- praisonai/agents_generator.py +903 -0
- praisonai/api/call.py +292 -0
- praisonai/auto.py +1197 -0
- praisonai/capabilities/__init__.py +275 -0
- praisonai/capabilities/a2a.py +140 -0
- praisonai/capabilities/assistants.py +283 -0
- praisonai/capabilities/audio.py +320 -0
- praisonai/capabilities/batches.py +469 -0
- praisonai/capabilities/completions.py +336 -0
- praisonai/capabilities/container_files.py +155 -0
- praisonai/capabilities/containers.py +93 -0
- praisonai/capabilities/embeddings.py +158 -0
- praisonai/capabilities/files.py +467 -0
- praisonai/capabilities/fine_tuning.py +293 -0
- praisonai/capabilities/guardrails.py +182 -0
- praisonai/capabilities/images.py +330 -0
- praisonai/capabilities/mcp.py +190 -0
- praisonai/capabilities/messages.py +270 -0
- praisonai/capabilities/moderations.py +154 -0
- praisonai/capabilities/ocr.py +217 -0
- praisonai/capabilities/passthrough.py +204 -0
- praisonai/capabilities/rag.py +207 -0
- praisonai/capabilities/realtime.py +160 -0
- praisonai/capabilities/rerank.py +165 -0
- praisonai/capabilities/responses.py +266 -0
- praisonai/capabilities/search.py +109 -0
- praisonai/capabilities/skills.py +133 -0
- praisonai/capabilities/vector_store_files.py +334 -0
- praisonai/capabilities/vector_stores.py +304 -0
- praisonai/capabilities/videos.py +141 -0
- praisonai/chainlit_ui.py +304 -0
- praisonai/chat/__init__.py +106 -0
- praisonai/chat/app.py +125 -0
- praisonai/cli/__init__.py +26 -0
- praisonai/cli/app.py +213 -0
- praisonai/cli/commands/__init__.py +75 -0
- praisonai/cli/commands/acp.py +70 -0
- praisonai/cli/commands/completion.py +333 -0
- praisonai/cli/commands/config.py +166 -0
- praisonai/cli/commands/debug.py +142 -0
- praisonai/cli/commands/diag.py +55 -0
- praisonai/cli/commands/doctor.py +166 -0
- praisonai/cli/commands/environment.py +179 -0
- praisonai/cli/commands/lsp.py +112 -0
- praisonai/cli/commands/mcp.py +210 -0
- praisonai/cli/commands/profile.py +457 -0
- praisonai/cli/commands/run.py +228 -0
- praisonai/cli/commands/schedule.py +150 -0
- praisonai/cli/commands/serve.py +97 -0
- praisonai/cli/commands/session.py +212 -0
- praisonai/cli/commands/traces.py +145 -0
- praisonai/cli/commands/version.py +101 -0
- praisonai/cli/configuration/__init__.py +18 -0
- praisonai/cli/configuration/loader.py +353 -0
- praisonai/cli/configuration/paths.py +114 -0
- praisonai/cli/configuration/schema.py +164 -0
- praisonai/cli/features/__init__.py +268 -0
- praisonai/cli/features/acp.py +236 -0
- praisonai/cli/features/action_orchestrator.py +546 -0
- praisonai/cli/features/agent_scheduler.py +773 -0
- praisonai/cli/features/agent_tools.py +474 -0
- praisonai/cli/features/agents.py +375 -0
- praisonai/cli/features/at_mentions.py +471 -0
- praisonai/cli/features/auto_memory.py +182 -0
- praisonai/cli/features/autonomy_mode.py +490 -0
- praisonai/cli/features/background.py +356 -0
- praisonai/cli/features/base.py +168 -0
- praisonai/cli/features/capabilities.py +1326 -0
- praisonai/cli/features/checkpoints.py +338 -0
- praisonai/cli/features/code_intelligence.py +652 -0
- praisonai/cli/features/compaction.py +294 -0
- praisonai/cli/features/compare.py +534 -0
- praisonai/cli/features/cost_tracker.py +514 -0
- praisonai/cli/features/debug.py +810 -0
- praisonai/cli/features/deploy.py +517 -0
- praisonai/cli/features/diag.py +289 -0
- praisonai/cli/features/doctor/__init__.py +63 -0
- praisonai/cli/features/doctor/checks/__init__.py +24 -0
- praisonai/cli/features/doctor/checks/acp_checks.py +240 -0
- praisonai/cli/features/doctor/checks/config_checks.py +366 -0
- praisonai/cli/features/doctor/checks/db_checks.py +366 -0
- praisonai/cli/features/doctor/checks/env_checks.py +543 -0
- praisonai/cli/features/doctor/checks/lsp_checks.py +199 -0
- praisonai/cli/features/doctor/checks/mcp_checks.py +349 -0
- praisonai/cli/features/doctor/checks/memory_checks.py +268 -0
- praisonai/cli/features/doctor/checks/network_checks.py +251 -0
- praisonai/cli/features/doctor/checks/obs_checks.py +328 -0
- praisonai/cli/features/doctor/checks/performance_checks.py +235 -0
- praisonai/cli/features/doctor/checks/permissions_checks.py +259 -0
- praisonai/cli/features/doctor/checks/selftest_checks.py +322 -0
- praisonai/cli/features/doctor/checks/serve_checks.py +426 -0
- praisonai/cli/features/doctor/checks/skills_checks.py +231 -0
- praisonai/cli/features/doctor/checks/tools_checks.py +371 -0
- praisonai/cli/features/doctor/engine.py +266 -0
- praisonai/cli/features/doctor/formatters.py +310 -0
- praisonai/cli/features/doctor/handler.py +397 -0
- praisonai/cli/features/doctor/models.py +264 -0
- praisonai/cli/features/doctor/registry.py +239 -0
- praisonai/cli/features/endpoints.py +1019 -0
- praisonai/cli/features/eval.py +560 -0
- praisonai/cli/features/external_agents.py +231 -0
- praisonai/cli/features/fast_context.py +410 -0
- praisonai/cli/features/flow_display.py +566 -0
- praisonai/cli/features/git_integration.py +651 -0
- praisonai/cli/features/guardrail.py +171 -0
- praisonai/cli/features/handoff.py +185 -0
- praisonai/cli/features/hooks.py +583 -0
- praisonai/cli/features/image.py +384 -0
- praisonai/cli/features/interactive_runtime.py +585 -0
- praisonai/cli/features/interactive_tools.py +380 -0
- praisonai/cli/features/interactive_tui.py +603 -0
- praisonai/cli/features/jobs.py +632 -0
- praisonai/cli/features/knowledge.py +531 -0
- praisonai/cli/features/lite.py +244 -0
- praisonai/cli/features/lsp_cli.py +225 -0
- praisonai/cli/features/mcp.py +169 -0
- praisonai/cli/features/message_queue.py +587 -0
- praisonai/cli/features/metrics.py +211 -0
- praisonai/cli/features/n8n.py +673 -0
- praisonai/cli/features/observability.py +293 -0
- praisonai/cli/features/ollama.py +361 -0
- praisonai/cli/features/output_style.py +273 -0
- praisonai/cli/features/package.py +631 -0
- praisonai/cli/features/performance.py +308 -0
- praisonai/cli/features/persistence.py +636 -0
- praisonai/cli/features/profile.py +226 -0
- praisonai/cli/features/profiler/__init__.py +81 -0
- praisonai/cli/features/profiler/core.py +558 -0
- praisonai/cli/features/profiler/optimizations.py +652 -0
- praisonai/cli/features/profiler/suite.py +386 -0
- praisonai/cli/features/profiling.py +350 -0
- praisonai/cli/features/queue/__init__.py +73 -0
- praisonai/cli/features/queue/manager.py +395 -0
- praisonai/cli/features/queue/models.py +286 -0
- praisonai/cli/features/queue/persistence.py +564 -0
- praisonai/cli/features/queue/scheduler.py +484 -0
- praisonai/cli/features/queue/worker.py +372 -0
- praisonai/cli/features/recipe.py +1723 -0
- praisonai/cli/features/recipes.py +449 -0
- praisonai/cli/features/registry.py +229 -0
- praisonai/cli/features/repo_map.py +860 -0
- praisonai/cli/features/router.py +466 -0
- praisonai/cli/features/sandbox_executor.py +515 -0
- praisonai/cli/features/serve.py +829 -0
- praisonai/cli/features/session.py +222 -0
- praisonai/cli/features/skills.py +856 -0
- praisonai/cli/features/slash_commands.py +650 -0
- praisonai/cli/features/telemetry.py +179 -0
- praisonai/cli/features/templates.py +1384 -0
- praisonai/cli/features/thinking.py +305 -0
- praisonai/cli/features/todo.py +334 -0
- praisonai/cli/features/tools.py +680 -0
- praisonai/cli/features/tui/__init__.py +83 -0
- praisonai/cli/features/tui/app.py +580 -0
- praisonai/cli/features/tui/cli.py +566 -0
- praisonai/cli/features/tui/debug.py +511 -0
- praisonai/cli/features/tui/events.py +99 -0
- praisonai/cli/features/tui/mock_provider.py +328 -0
- praisonai/cli/features/tui/orchestrator.py +652 -0
- praisonai/cli/features/tui/screens/__init__.py +50 -0
- praisonai/cli/features/tui/screens/main.py +245 -0
- praisonai/cli/features/tui/screens/queue.py +174 -0
- praisonai/cli/features/tui/screens/session.py +124 -0
- praisonai/cli/features/tui/screens/settings.py +148 -0
- praisonai/cli/features/tui/widgets/__init__.py +56 -0
- praisonai/cli/features/tui/widgets/chat.py +261 -0
- praisonai/cli/features/tui/widgets/composer.py +224 -0
- praisonai/cli/features/tui/widgets/queue_panel.py +200 -0
- praisonai/cli/features/tui/widgets/status.py +167 -0
- praisonai/cli/features/tui/widgets/tool_panel.py +248 -0
- praisonai/cli/features/workflow.py +720 -0
- praisonai/cli/legacy.py +236 -0
- praisonai/cli/main.py +5559 -0
- praisonai/cli/schedule_cli.py +54 -0
- praisonai/cli/state/__init__.py +31 -0
- praisonai/cli/state/identifiers.py +161 -0
- praisonai/cli/state/sessions.py +313 -0
- praisonai/code/__init__.py +93 -0
- praisonai/code/agent_tools.py +344 -0
- praisonai/code/diff/__init__.py +21 -0
- praisonai/code/diff/diff_strategy.py +432 -0
- praisonai/code/tools/__init__.py +27 -0
- praisonai/code/tools/apply_diff.py +221 -0
- praisonai/code/tools/execute_command.py +275 -0
- praisonai/code/tools/list_files.py +274 -0
- praisonai/code/tools/read_file.py +206 -0
- praisonai/code/tools/search_replace.py +248 -0
- praisonai/code/tools/write_file.py +217 -0
- praisonai/code/utils/__init__.py +46 -0
- praisonai/code/utils/file_utils.py +307 -0
- praisonai/code/utils/ignore_utils.py +308 -0
- praisonai/code/utils/text_utils.py +276 -0
- praisonai/db/__init__.py +64 -0
- praisonai/db/adapter.py +531 -0
- praisonai/deploy/__init__.py +62 -0
- praisonai/deploy/api.py +231 -0
- praisonai/deploy/docker.py +454 -0
- praisonai/deploy/doctor.py +367 -0
- praisonai/deploy/main.py +327 -0
- praisonai/deploy/models.py +179 -0
- praisonai/deploy/providers/__init__.py +33 -0
- praisonai/deploy/providers/aws.py +331 -0
- praisonai/deploy/providers/azure.py +358 -0
- praisonai/deploy/providers/base.py +101 -0
- praisonai/deploy/providers/gcp.py +314 -0
- praisonai/deploy/schema.py +208 -0
- praisonai/deploy.py +185 -0
- praisonai/endpoints/__init__.py +53 -0
- praisonai/endpoints/a2u_server.py +410 -0
- praisonai/endpoints/discovery.py +165 -0
- praisonai/endpoints/providers/__init__.py +28 -0
- praisonai/endpoints/providers/a2a.py +253 -0
- praisonai/endpoints/providers/a2u.py +208 -0
- praisonai/endpoints/providers/agents_api.py +171 -0
- praisonai/endpoints/providers/base.py +231 -0
- praisonai/endpoints/providers/mcp.py +263 -0
- praisonai/endpoints/providers/recipe.py +206 -0
- praisonai/endpoints/providers/tools_mcp.py +150 -0
- praisonai/endpoints/registry.py +131 -0
- praisonai/endpoints/server.py +161 -0
- praisonai/inbuilt_tools/__init__.py +24 -0
- praisonai/inbuilt_tools/autogen_tools.py +117 -0
- praisonai/inc/__init__.py +2 -0
- praisonai/inc/config.py +96 -0
- praisonai/inc/models.py +155 -0
- praisonai/integrations/__init__.py +56 -0
- praisonai/integrations/base.py +303 -0
- praisonai/integrations/claude_code.py +270 -0
- praisonai/integrations/codex_cli.py +255 -0
- praisonai/integrations/cursor_cli.py +195 -0
- praisonai/integrations/gemini_cli.py +222 -0
- praisonai/jobs/__init__.py +67 -0
- praisonai/jobs/executor.py +425 -0
- praisonai/jobs/models.py +230 -0
- praisonai/jobs/router.py +314 -0
- praisonai/jobs/server.py +186 -0
- praisonai/jobs/store.py +203 -0
- praisonai/llm/__init__.py +66 -0
- praisonai/llm/registry.py +382 -0
- praisonai/mcp_server/__init__.py +152 -0
- praisonai/mcp_server/adapters/__init__.py +74 -0
- praisonai/mcp_server/adapters/agents.py +128 -0
- praisonai/mcp_server/adapters/capabilities.py +168 -0
- praisonai/mcp_server/adapters/cli_tools.py +568 -0
- praisonai/mcp_server/adapters/extended_capabilities.py +462 -0
- praisonai/mcp_server/adapters/knowledge.py +93 -0
- praisonai/mcp_server/adapters/memory.py +104 -0
- praisonai/mcp_server/adapters/prompts.py +306 -0
- praisonai/mcp_server/adapters/resources.py +124 -0
- praisonai/mcp_server/adapters/tools_bridge.py +280 -0
- praisonai/mcp_server/auth/__init__.py +48 -0
- praisonai/mcp_server/auth/api_key.py +291 -0
- praisonai/mcp_server/auth/oauth.py +460 -0
- praisonai/mcp_server/auth/oidc.py +289 -0
- praisonai/mcp_server/auth/scopes.py +260 -0
- praisonai/mcp_server/cli.py +852 -0
- praisonai/mcp_server/elicitation.py +445 -0
- praisonai/mcp_server/icons.py +302 -0
- praisonai/mcp_server/recipe_adapter.py +573 -0
- praisonai/mcp_server/recipe_cli.py +824 -0
- praisonai/mcp_server/registry.py +703 -0
- praisonai/mcp_server/sampling.py +422 -0
- praisonai/mcp_server/server.py +490 -0
- praisonai/mcp_server/tasks.py +443 -0
- praisonai/mcp_server/transports/__init__.py +18 -0
- praisonai/mcp_server/transports/http_stream.py +376 -0
- praisonai/mcp_server/transports/stdio.py +132 -0
- praisonai/persistence/__init__.py +84 -0
- praisonai/persistence/config.py +238 -0
- praisonai/persistence/conversation/__init__.py +25 -0
- praisonai/persistence/conversation/async_mysql.py +427 -0
- praisonai/persistence/conversation/async_postgres.py +410 -0
- praisonai/persistence/conversation/async_sqlite.py +371 -0
- praisonai/persistence/conversation/base.py +151 -0
- praisonai/persistence/conversation/json_store.py +250 -0
- praisonai/persistence/conversation/mysql.py +387 -0
- praisonai/persistence/conversation/postgres.py +401 -0
- praisonai/persistence/conversation/singlestore.py +240 -0
- praisonai/persistence/conversation/sqlite.py +341 -0
- praisonai/persistence/conversation/supabase.py +203 -0
- praisonai/persistence/conversation/surrealdb.py +287 -0
- praisonai/persistence/factory.py +301 -0
- praisonai/persistence/hooks/__init__.py +18 -0
- praisonai/persistence/hooks/agent_hooks.py +297 -0
- praisonai/persistence/knowledge/__init__.py +26 -0
- praisonai/persistence/knowledge/base.py +144 -0
- praisonai/persistence/knowledge/cassandra.py +232 -0
- praisonai/persistence/knowledge/chroma.py +295 -0
- praisonai/persistence/knowledge/clickhouse.py +242 -0
- praisonai/persistence/knowledge/cosmosdb_vector.py +438 -0
- praisonai/persistence/knowledge/couchbase.py +286 -0
- praisonai/persistence/knowledge/lancedb.py +216 -0
- praisonai/persistence/knowledge/langchain_adapter.py +291 -0
- praisonai/persistence/knowledge/lightrag_adapter.py +212 -0
- praisonai/persistence/knowledge/llamaindex_adapter.py +256 -0
- praisonai/persistence/knowledge/milvus.py +277 -0
- praisonai/persistence/knowledge/mongodb_vector.py +306 -0
- praisonai/persistence/knowledge/pgvector.py +335 -0
- praisonai/persistence/knowledge/pinecone.py +253 -0
- praisonai/persistence/knowledge/qdrant.py +301 -0
- praisonai/persistence/knowledge/redis_vector.py +291 -0
- praisonai/persistence/knowledge/singlestore_vector.py +299 -0
- praisonai/persistence/knowledge/surrealdb_vector.py +309 -0
- praisonai/persistence/knowledge/upstash_vector.py +266 -0
- praisonai/persistence/knowledge/weaviate.py +223 -0
- praisonai/persistence/migrations/__init__.py +10 -0
- praisonai/persistence/migrations/manager.py +251 -0
- praisonai/persistence/orchestrator.py +406 -0
- praisonai/persistence/state/__init__.py +21 -0
- praisonai/persistence/state/async_mongodb.py +200 -0
- praisonai/persistence/state/base.py +107 -0
- praisonai/persistence/state/dynamodb.py +226 -0
- praisonai/persistence/state/firestore.py +175 -0
- praisonai/persistence/state/gcs.py +155 -0
- praisonai/persistence/state/memory.py +245 -0
- praisonai/persistence/state/mongodb.py +158 -0
- praisonai/persistence/state/redis.py +190 -0
- praisonai/persistence/state/upstash.py +144 -0
- praisonai/persistence/tests/__init__.py +3 -0
- praisonai/persistence/tests/test_all_backends.py +633 -0
- praisonai/profiler.py +1214 -0
- praisonai/recipe/__init__.py +134 -0
- praisonai/recipe/bridge.py +278 -0
- praisonai/recipe/core.py +893 -0
- praisonai/recipe/exceptions.py +54 -0
- praisonai/recipe/history.py +402 -0
- praisonai/recipe/models.py +266 -0
- praisonai/recipe/operations.py +440 -0
- praisonai/recipe/policy.py +422 -0
- praisonai/recipe/registry.py +849 -0
- praisonai/recipe/runtime.py +214 -0
- praisonai/recipe/security.py +711 -0
- praisonai/recipe/serve.py +859 -0
- praisonai/recipe/server.py +613 -0
- praisonai/scheduler/__init__.py +45 -0
- praisonai/scheduler/agent_scheduler.py +552 -0
- praisonai/scheduler/base.py +124 -0
- praisonai/scheduler/daemon_manager.py +225 -0
- praisonai/scheduler/state_manager.py +155 -0
- praisonai/scheduler/yaml_loader.py +193 -0
- praisonai/scheduler.py +194 -0
- praisonai/setup/__init__.py +1 -0
- praisonai/setup/build.py +21 -0
- praisonai/setup/post_install.py +23 -0
- praisonai/setup/setup_conda_env.py +25 -0
- praisonai/setup.py +16 -0
- praisonai/templates/__init__.py +116 -0
- praisonai/templates/cache.py +364 -0
- praisonai/templates/dependency_checker.py +358 -0
- praisonai/templates/discovery.py +391 -0
- praisonai/templates/loader.py +564 -0
- praisonai/templates/registry.py +511 -0
- praisonai/templates/resolver.py +206 -0
- praisonai/templates/security.py +327 -0
- praisonai/templates/tool_override.py +498 -0
- praisonai/templates/tools_doctor.py +256 -0
- praisonai/test.py +105 -0
- praisonai/train.py +562 -0
- praisonai/train_vision.py +306 -0
- praisonai/ui/agents.py +824 -0
- praisonai/ui/callbacks.py +57 -0
- praisonai/ui/chainlit_compat.py +246 -0
- praisonai/ui/chat.py +532 -0
- praisonai/ui/code.py +717 -0
- praisonai/ui/colab.py +474 -0
- praisonai/ui/colab_chainlit.py +81 -0
- praisonai/ui/components/aicoder.py +284 -0
- praisonai/ui/context.py +283 -0
- praisonai/ui/database_config.py +56 -0
- praisonai/ui/db.py +294 -0
- praisonai/ui/realtime.py +488 -0
- praisonai/ui/realtimeclient/__init__.py +756 -0
- praisonai/ui/realtimeclient/tools.py +242 -0
- praisonai/ui/sql_alchemy.py +710 -0
- praisonai/upload_vision.py +140 -0
- praisonai/version.py +1 -0
- praisonai-3.0.0.dist-info/METADATA +3493 -0
- praisonai-3.0.0.dist-info/RECORD +393 -0
- praisonai-3.0.0.dist-info/WHEEL +5 -0
- praisonai-3.0.0.dist-info/entry_points.txt +4 -0
- praisonai-3.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Reranker Implementations for PraisonAI.
|
|
3
|
+
|
|
4
|
+
Provides concrete implementations of RerankerProtocol:
|
|
5
|
+
- LLMReranker: Uses LLM to rerank documents
|
|
6
|
+
- CrossEncoderReranker: Uses sentence-transformers cross-encoder
|
|
7
|
+
- CohereReranker: Uses Cohere rerank API
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from typing import Any, List, Optional
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
# Lazy import flags
|
|
16
|
+
_SENTENCE_TRANSFORMERS_AVAILABLE = None
|
|
17
|
+
_COHERE_AVAILABLE = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _check_sentence_transformers():
|
|
21
|
+
"""Check if sentence-transformers is available."""
|
|
22
|
+
global _SENTENCE_TRANSFORMERS_AVAILABLE
|
|
23
|
+
if _SENTENCE_TRANSFORMERS_AVAILABLE is None:
|
|
24
|
+
import importlib.util
|
|
25
|
+
_SENTENCE_TRANSFORMERS_AVAILABLE = importlib.util.find_spec("sentence_transformers") is not None
|
|
26
|
+
return _SENTENCE_TRANSFORMERS_AVAILABLE
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _check_cohere():
    """Return True if the cohere SDK is importable.

    Result is cached in the module-level ``_COHERE_AVAILABLE`` flag so the
    import-system probe happens only on the first call.
    """
    global _COHERE_AVAILABLE
    if _COHERE_AVAILABLE is not None:
        return _COHERE_AVAILABLE
    import importlib.util
    _COHERE_AVAILABLE = importlib.util.find_spec("cohere") is not None
    return _COHERE_AVAILABLE
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class LLMReranker:
    """
    LLM-based reranker.

    Scores each document's relevance to a query by prompting an LLM for a
    0-10 rating (normalized to 0-1), then sorts documents by that score.
    """

    name: str = "llm"

    def __init__(
        self,
        llm: Optional[Any] = None,
        model: Optional[str] = None,
        batch_size: int = 5,
        **kwargs
    ):
        """
        Args:
            llm: Optional client. If it exposes a callable ``chat`` method
                (e.g. a PraisonAI Agent) it is used directly; otherwise
                litellm is called with ``model``.
            model: litellm model name; defaults to "gpt-4o-mini".
            batch_size: Number of documents scored per batch.
        """
        self.llm = llm
        self.model = model or "gpt-4o-mini"
        self.batch_size = batch_size
        self._client = None  # reserved for future client caching

    def _get_client(self):
        """Get LLM client - returns self.llm if provided, otherwise None (use litellm)."""
        if self.llm:
            return self.llm
        return None  # Will use litellm directly

    def rerank(
        self,
        query: str,
        documents: List[str],
        top_k: Optional[int] = None,
        **kwargs
    ) -> List[Any]:
        """Rerank documents using LLM scoring.

        Args:
            query: The search query.
            documents: Candidate document texts.
            top_k: If given, truncate to the top_k highest-scoring results.

        Returns:
            RerankResult list sorted by descending score.
        """
        # NOTE: the RerankResult instances are produced by _score_batch; an
        # unused import of RerankResult here was removed.
        results = []

        # Process in batches; start_idx preserves each document's original index.
        for i in range(0, len(documents), self.batch_size):
            batch = documents[i:i + self.batch_size]
            batch_results = self._score_batch(query, batch, start_idx=i)
            results.extend(batch_results)

        # Highest relevance first
        results.sort(key=lambda x: x.score, reverse=True)

        if top_k is not None:
            results = results[:top_k]

        return results

    async def arerank(
        self,
        query: str,
        documents: List[str],
        top_k: Optional[int] = None,
        **kwargs
    ) -> List[Any]:
        """Async version (just calls sync)."""
        return self.rerank(query, documents, top_k, **kwargs)

    def _score_batch(self, query: str, documents: List[str], start_idx: int = 0) -> List[Any]:
        """Score a batch of documents, tagging each with its original index."""
        from praisonaiagents.knowledge.rerankers import RerankResult

        results = []

        for i, doc in enumerate(documents):
            score = self._score_document(query, doc)
            results.append(RerankResult(
                text=doc,
                score=score,
                original_index=start_idx + i
            ))

        return results

    def _score_document(self, query: str, document: str) -> float:
        """Score a single document's relevance to query.

        Returns a score in [0, 1]; falls back to 0.5 when the LLM call fails
        or its reply cannot be parsed as a number.
        """
        try:
            client = self._get_client()

            prompt = f"""Rate the relevance of this document to the query on a scale of 0 to 10.
Return ONLY a number between 0 and 10.

Query: {query}

Document: {document[:1000]}

Relevance score (0-10):"""

            # Check if client has a callable 'chat' method (PraisonAI Agent)
            if client is not None and hasattr(client, 'chat') and callable(getattr(client, 'chat', None)):
                # PraisonAI Agent or similar with callable chat method
                response = client.chat(prompt)
            else:
                # Use litellm for unified multi-provider support
                import litellm

                # litellm handles model-specific quirks automatically
                completion = litellm.completion(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=10
                )
                response = completion.choices[0].message.content

            # Parse the first whitespace-separated token as the 0-10 score.
            if response:
                try:
                    score = float(response.strip().split()[0])
                    return min(max(score / 10.0, 0.0), 1.0)  # Normalize to 0-1
                except (ValueError, IndexError):
                    pass
        except Exception as e:
            logger.warning(f"LLM scoring failed: {e}")

        return 0.5  # Default score
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class CrossEncoderReranker:
    """
    Cross-encoder reranker using sentence-transformers.

    Scores each (query, document) pair jointly with a cross-encoder model
    for accurate relevance ranking.
    """

    name: str = "cross_encoder"

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        **kwargs
    ):
        if not _check_sentence_transformers():
            raise ImportError(
                "sentence-transformers is required for CrossEncoderReranker. "
                "Install with: pip install sentence-transformers"
            )

        from sentence_transformers import CrossEncoder
        self.model = CrossEncoder(model_name)

    def rerank(
        self,
        query: str,
        documents: List[str],
        top_k: Optional[int] = None,
        **kwargs
    ) -> List[Any]:
        """Rerank documents using cross-encoder."""
        from praisonaiagents.knowledge.rerankers import RerankResult

        # Score all (query, document) pairs in a single model call.
        scores = self.model.predict([[query, doc] for doc in documents])

        # Wrap each score, remembering the document's original position,
        # then order by descending relevance.
        ranked = sorted(
            (
                RerankResult(text=doc, score=float(s), original_index=idx)
                for idx, (doc, s) in enumerate(zip(documents, scores))
            ),
            key=lambda r: r.score,
            reverse=True,
        )

        return ranked if top_k is None else ranked[:top_k]

    async def arerank(
        self,
        query: str,
        documents: List[str],
        top_k: Optional[int] = None,
        **kwargs
    ) -> List[Any]:
        """Async version (just calls sync)."""
        return self.rerank(query, documents, top_k, **kwargs)
|
225
|
+
|
|
226
|
+
|
|
227
|
+
class CohereReranker:
    """
    Cohere reranker using Cohere's rerank API.
    """

    name: str = "cohere"

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "rerank-english-v3.0",
        **kwargs
    ):
        if not _check_cohere():
            raise ImportError(
                "cohere is required for CohereReranker. "
                "Install with: pip install cohere"
            )

        import os
        import cohere

        # Explicit key wins; otherwise fall back to the environment.
        self.api_key = api_key or os.environ.get("COHERE_API_KEY")
        if not self.api_key:
            raise ValueError("COHERE_API_KEY environment variable or api_key parameter required")

        self.model = model
        self._client = cohere.Client(self.api_key)

    def rerank(
        self,
        query: str,
        documents: List[str],
        top_k: Optional[int] = None,
        **kwargs
    ) -> List[Any]:
        """Rerank documents using Cohere API."""
        from praisonaiagents.knowledge.rerankers import RerankResult

        api_response = self._client.rerank(
            query=query,
            documents=documents,
            top_n=top_k or len(documents),
            model=self.model
        )

        # Cohere returns hits already sorted by relevance; map each back to
        # its source document via the returned index.
        return [
            RerankResult(
                text=documents[hit.index],
                score=hit.relevance_score,
                original_index=hit.index,
            )
            for hit in api_response.results
        ]

    async def arerank(
        self,
        query: str,
        documents: List[str],
        top_k: Optional[int] = None,
        **kwargs
    ) -> List[Any]:
        """Async version (just calls sync)."""
        return self.rerank(query, documents, top_k, **kwargs)
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def register_default_rerankers():
    """Register all default rerankers with the registry.

    The LLM reranker is always registered; cross-encoder and Cohere are
    registered only when their optional dependencies are installed.
    """
    from praisonaiagents.knowledge.rerankers import get_reranker_registry

    registry = get_reranker_registry()

    # (registry name, implementation, availability probe or None for always)
    candidates = [
        ("llm", LLMReranker, None),
        ("cross_encoder", CrossEncoderReranker, _check_sentence_transformers),
        ("cohere", CohereReranker, _check_cohere),
    ]
    for reg_name, impl, probe in candidates:
        if probe is None or probe():
            registry.register(reg_name, impl)

    logger.debug("Registered default rerankers")
|
312
|
+
|
|
313
|
+
|
|
314
|
+
# Auto-register on import so the reranker registry is populated without
# requiring callers to run any explicit setup step.
register_default_rerankers()
|
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Retriever Implementations for PraisonAI.
|
|
3
|
+
|
|
4
|
+
Provides concrete implementations of RetrieverProtocol:
|
|
5
|
+
- BasicRetriever: Simple vector similarity retrieval
|
|
6
|
+
- FusionRetriever: Multi-query with RRF fusion
|
|
7
|
+
- RecursiveRetriever: Depth-limited sub-retrieval
|
|
8
|
+
- AutoMergeRetriever: Merges adjacent chunks
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BasicRetriever:
    """
    Basic vector similarity retriever.

    Embeds the query, runs a similarity search against the vector store,
    and wraps each hit as a RetrievalResult.
    """

    name: str = "basic"

    def __init__(
        self,
        vector_store: Any,
        embedding_fn: Optional[Callable[[str], List[float]]] = None,
        top_k: int = 10,
        **kwargs
    ):
        """
        Args:
            vector_store: Store exposing ``query(embedding, top_k, filter)``.
            embedding_fn: Maps a query string to its embedding vector.
                Required by retrieve(); a ValueError is raised otherwise.
            top_k: Default number of results when not overridden per call.
        """
        self.vector_store = vector_store
        self.embedding_fn = embedding_fn
        self.top_k = top_k

    def retrieve(
        self,
        query: str,
        top_k: Optional[int] = None,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> List[Any]:
        """Retrieve documents by vector similarity.

        Args:
            query: Natural-language query text.
            top_k: Per-call override of the default result count.
            filter: Optional metadata filter passed through to the store.

        Returns:
            List of RetrievalResult objects.

        Raises:
            ValueError: If no embedding_fn was configured.
        """
        from praisonaiagents.knowledge.retrieval import RetrievalResult

        k = top_k or self.top_k

        # Get query embedding
        if self.embedding_fn:
            query_embedding = self.embedding_fn(query)
        else:
            raise ValueError("embedding_fn is required for BasicRetriever")

        # Query vector store
        results = self.vector_store.query(
            embedding=query_embedding,
            top_k=k,
            filter=filter
        )

        # Convert store hits to RetrievalResult. Hits are duck-typed: any of
        # text/score/metadata/id may be missing depending on the backend.
        retrieval_results = []
        for r in results:
            # Normalize metadata to a dict: the original code crashed with
            # AttributeError on hits whose metadata attribute exists but is None.
            metadata = getattr(r, 'metadata', None) or {}
            retrieval_results.append(RetrievalResult(
                text=r.text if hasattr(r, 'text') else str(r),
                score=r.score if hasattr(r, 'score') else 1.0,
                metadata=metadata,
                doc_id=getattr(r, 'id', None),
                chunk_index=metadata.get('chunk_index')
            ))

        return retrieval_results

    async def aretrieve(
        self,
        query: str,
        top_k: Optional[int] = None,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> List[Any]:
        """Async version (just calls sync)."""
        return self.retrieve(query, top_k, filter, **kwargs)
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class FusionRetriever:
    """
    Fusion retriever using Reciprocal Rank Fusion (RRF).

    Generates multiple variations of the input query (via an optional LLM,
    with simple heuristic fallbacks), retrieves once per variation, and
    fuses the ranked lists with reciprocal rank fusion.
    """

    name: str = "fusion"

    def __init__(
        self,
        vector_store: Any,
        embedding_fn: Optional[Callable[[str], List[float]]] = None,
        llm: Optional[Any] = None,
        num_queries: int = 3,
        top_k: int = 10,
        rrf_k: int = 60,
        **kwargs
    ):
        # Backend exposing query(embedding=..., top_k=..., filter=...).
        self.vector_store = vector_store
        # Maps query text -> embedding vector; required for retrieval.
        self.embedding_fn = embedding_fn
        # Optional chat model used to rephrase the query.
        self.llm = llm
        # Total number of query variations (including the original).
        self.num_queries = num_queries
        self.top_k = top_k
        # Smoothing constant for reciprocal rank fusion.
        self.rrf_k = rrf_k

    def retrieve(
        self,
        query: str,
        top_k: Optional[int] = None,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> List[Any]:
        """Retrieve with query fusion.

        Runs the vector search once per generated query variation and merges
        the per-query ranked lists with reciprocal rank fusion.
        """
        from praisonaiagents.knowledge.retrieval import RetrievalResult, reciprocal_rank_fusion

        # Fix: explicit None check — `top_k or self.top_k` replaced an
        # explicit top_k=0 with the default.
        k = self.top_k if top_k is None else top_k

        # Generate query variations (original query is always first).
        queries = self._generate_query_variations(query)

        # Retrieve a ranked list per query variation.
        all_results = []
        for q in queries:
            if self.embedding_fn:
                q_embedding = self.embedding_fn(q)
                results = self.vector_store.query(
                    embedding=q_embedding,
                    top_k=k,
                    filter=filter
                )

                # Normalize backend hits into RetrievalResult objects.
                query_results = []
                for r in results:
                    query_results.append(RetrievalResult(
                        text=r.text if hasattr(r, 'text') else str(r),
                        score=r.score if hasattr(r, 'score') else 1.0,
                        metadata=r.metadata if hasattr(r, 'metadata') else {},
                        doc_id=r.id if hasattr(r, 'id') else None
                    ))
                all_results.append(query_results)

        # Fuse the per-query ranked lists using RRF.
        fused = reciprocal_rank_fusion(all_results, k=self.rrf_k)

        return fused[:k]

    async def aretrieve(
        self,
        query: str,
        top_k: Optional[int] = None,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> List[Any]:
        """Async wrapper: delegates to the synchronous ``retrieve``."""
        return self.retrieve(query, top_k, filter, **kwargs)

    def _generate_query_variations(self, query: str) -> List[str]:
        """Generate up to ``self.num_queries`` variations of *query*.

        Tries the configured LLM first, then falls back to simple heuristics
        (question form, leading-keyword form). The original query is always
        the first element of the returned list.
        """
        queries = [query]

        # Try LLM-based variation if available.
        if self.llm and len(queries) < self.num_queries:
            try:
                prompt = f"""Generate {self.num_queries - 1} alternative phrasings of this search query.
Return only the queries, one per line.

Original query: {query}"""
                response = self.llm.chat(prompt)
                if response:
                    for line in response.strip().split("\n"):
                        line = line.strip()
                        if line and line != query:
                            queries.append(line)
                            if len(queries) >= self.num_queries:
                                break
            except Exception as e:
                logger.warning(f"LLM query variation failed: {e}")

        # Fallback: simple heuristic variations.
        if len(queries) < self.num_queries:
            # Add question form.
            if not query.endswith("?"):
                queries.append(f"What is {query}?")
            # Add keyword form (first five words).
            keywords = " ".join(query.split()[:5])
            if keywords != query:
                queries.append(keywords)

        return queries[:self.num_queries]
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class RecursiveRetriever:
    """
    Recursive retriever with depth-limited sub-retrieval.

    Retrieves, generates follow-up queries from the results (via an optional
    LLM), and retrieves again, up to ``max_depth`` rounds. Results are
    deduplicated across rounds and returned sorted by score.
    """

    name: str = "recursive"

    def __init__(
        self,
        vector_store: Any,
        embedding_fn: Optional[Callable[[str], List[float]]] = None,
        llm: Optional[Any] = None,
        max_depth: int = 2,
        top_k: int = 10,
        **kwargs
    ):
        # Backend exposing query(embedding=..., top_k=..., filter=...).
        self.vector_store = vector_store
        # Maps query text -> embedding vector; required for retrieval.
        self.embedding_fn = embedding_fn
        # Optional chat model used to generate follow-up queries.
        self.llm = llm
        # Maximum number of retrieval rounds.
        self.max_depth = max_depth
        self.top_k = top_k

    def retrieve(
        self,
        query: str,
        top_k: Optional[int] = None,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> List[Any]:
        """Retrieve with recursive expansion.

        Each round retrieves for the current query set, collects unseen
        results, and (when an LLM is configured and depth remains) derives
        follow-up queries from newly-seen result text for the next round.
        """
        from praisonaiagents.knowledge.retrieval import RetrievalResult

        # Fix: explicit None check — `top_k or self.top_k` replaced an
        # explicit top_k=0 with the default.
        k = self.top_k if top_k is None else top_k
        seen_ids = set()
        all_results = []

        current_queries = [query]

        for depth in range(self.max_depth):
            next_queries = []

            for q in current_queries:
                if self.embedding_fn:
                    q_embedding = self.embedding_fn(q)
                    results = self.vector_store.query(
                        embedding=q_embedding,
                        top_k=k,
                        filter=filter
                    )

                    for r in results:
                        # Dedup key: prefer the backend id; otherwise a text
                        # prefix. Fix: guard the text access too — previously
                        # a result lacking both `id` and `text` raised
                        # AttributeError here, unlike the guarded access
                        # used everywhere else in this module.
                        r_text = r.text if hasattr(r, 'text') else str(r)
                        r_id = r.id if hasattr(r, 'id') else r_text[:50]
                        if r_id not in seen_ids:
                            seen_ids.add(r_id)
                            all_results.append(RetrievalResult(
                                text=r_text,
                                score=r.score if hasattr(r, 'score') else 1.0,
                                metadata=r.metadata if hasattr(r, 'metadata') else {},
                                doc_id=r.id if hasattr(r, 'id') else None
                            ))

                            # Generate a follow-up query from this new result
                            # unless this is already the last round.
                            if depth < self.max_depth - 1 and self.llm:
                                follow_up = self._generate_follow_up(query, r_text)
                                if follow_up:
                                    next_queries.append(follow_up)

            current_queries = next_queries[:3]  # Limit expansion per round
            if not current_queries:
                break

        # Best results first.
        all_results.sort(key=lambda x: x.score, reverse=True)
        return all_results[:k]

    async def aretrieve(
        self,
        query: str,
        top_k: Optional[int] = None,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> List[Any]:
        """Async wrapper: delegates to the synchronous ``retrieve``."""
        return self.retrieve(query, top_k, filter, **kwargs)

    def _generate_follow_up(self, original_query: str, context: str) -> Optional[str]:
        """Generate a follow-up query based on retrieved context.

        Returns None when no LLM is configured or generation fails.
        """
        if not self.llm:
            return None

        try:
            prompt = f"""Based on this search result, generate a follow-up query to find more relevant information.
Return only the query, nothing else.

Original query: {original_query}
Search result: {context[:500]}"""
            response = self.llm.chat(prompt)
            if response:
                return response.strip()
        except Exception as e:
            logger.warning(f"Follow-up query generation failed: {e}")

        return None
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
class AutoMergeRetriever:
    """
    Auto-merge retriever that combines adjacent chunks.

    Retrieves more chunks than requested, then merges chunks from the same
    document whose chunk indices are within ``max_gap`` of each other.
    """

    name: str = "auto_merge"

    def __init__(
        self,
        vector_store: Any,
        embedding_fn: Optional[Callable[[str], List[float]]] = None,
        top_k: int = 10,
        max_gap: int = 1,
        **kwargs
    ):
        # Backend exposing query(embedding=..., top_k=..., filter=...).
        self.vector_store = vector_store
        # Maps query text -> embedding vector; required for retrieval.
        self.embedding_fn = embedding_fn
        self.top_k = top_k
        # Largest chunk-index gap that still merges two chunks.
        self.max_gap = max_gap

    def retrieve(
        self,
        query: str,
        top_k: Optional[int] = None,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> List[Any]:
        """Retrieve and merge adjacent chunks.

        Returns an empty list when no ``embedding_fn`` is configured
        (preserving this class's original silent-fallback behavior, which
        intentionally differs from BasicRetriever's ValueError).
        """
        # Early guard replaces the original trailing `return []`, which read
        # like dead code after the in-branch return.
        if not self.embedding_fn:
            return []

        from praisonaiagents.knowledge.retrieval import RetrievalResult, merge_adjacent_chunks

        # Fix: explicit None check — `top_k or self.top_k` replaced an
        # explicit top_k=0 with the default.
        k = self.top_k if top_k is None else top_k

        # Over-fetch so that merging adjacent chunks still leaves ~k results.
        fetch_k = k * 3

        q_embedding = self.embedding_fn(query)
        results = self.vector_store.query(
            embedding=q_embedding,
            top_k=fetch_k,
            filter=filter
        )

        # Normalize backend hits; doc_id/chunk_index come from metadata so
        # merge_adjacent_chunks can group and order chunks per document.
        retrieval_results = []
        for r in results:
            metadata = r.metadata if hasattr(r, 'metadata') else {}
            retrieval_results.append(RetrievalResult(
                text=r.text if hasattr(r, 'text') else str(r),
                score=r.score if hasattr(r, 'score') else 1.0,
                metadata=metadata,
                doc_id=metadata.get('doc_id') if metadata else None,
                chunk_index=metadata.get('chunk_index') if metadata else None
            ))

        # Merge adjacent chunks and trim to the requested count.
        merged = merge_adjacent_chunks(retrieval_results, max_gap=self.max_gap)
        return merged[:k]

    async def aretrieve(
        self,
        query: str,
        top_k: Optional[int] = None,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> List[Any]:
        """Async wrapper: delegates to the synchronous ``retrieve``."""
        return self.retrieve(query, top_k, filter, **kwargs)
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def register_default_retrievers():
    """Register every built-in retriever class with the shared registry."""
    from praisonaiagents.knowledge.retrieval import get_retriever_registry

    registry = get_retriever_registry()

    # Name -> class pairs registered under each retriever's public name.
    for retriever_name, retriever_cls in (
        ("basic", BasicRetriever),
        ("fusion", FusionRetriever),
        ("recursive", RecursiveRetriever),
        ("auto_merge", AutoMergeRetriever),
    ):
        registry.register(retriever_name, retriever_cls)

    logger.debug("Registered default retrievers")
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
# Auto-register on import: importing this module makes the built-in
# retrievers ("basic", "fusion", "recursive", "auto_merge") available
# from the global retriever registry as a side effect.
register_default_retrievers()
|