agno 2.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/__init__.py +8 -0
- agno/agent/__init__.py +51 -0
- agno/agent/agent.py +10405 -0
- agno/api/__init__.py +0 -0
- agno/api/agent.py +28 -0
- agno/api/api.py +40 -0
- agno/api/evals.py +22 -0
- agno/api/os.py +17 -0
- agno/api/routes.py +13 -0
- agno/api/schemas/__init__.py +9 -0
- agno/api/schemas/agent.py +16 -0
- agno/api/schemas/evals.py +16 -0
- agno/api/schemas/os.py +14 -0
- agno/api/schemas/response.py +6 -0
- agno/api/schemas/team.py +16 -0
- agno/api/schemas/utils.py +21 -0
- agno/api/schemas/workflows.py +16 -0
- agno/api/settings.py +53 -0
- agno/api/team.py +30 -0
- agno/api/workflow.py +28 -0
- agno/cloud/aws/base.py +214 -0
- agno/cloud/aws/s3/__init__.py +2 -0
- agno/cloud/aws/s3/api_client.py +43 -0
- agno/cloud/aws/s3/bucket.py +195 -0
- agno/cloud/aws/s3/object.py +57 -0
- agno/culture/__init__.py +3 -0
- agno/culture/manager.py +956 -0
- agno/db/__init__.py +24 -0
- agno/db/async_postgres/__init__.py +3 -0
- agno/db/base.py +598 -0
- agno/db/dynamo/__init__.py +3 -0
- agno/db/dynamo/dynamo.py +2042 -0
- agno/db/dynamo/schemas.py +314 -0
- agno/db/dynamo/utils.py +743 -0
- agno/db/firestore/__init__.py +3 -0
- agno/db/firestore/firestore.py +1795 -0
- agno/db/firestore/schemas.py +140 -0
- agno/db/firestore/utils.py +376 -0
- agno/db/gcs_json/__init__.py +3 -0
- agno/db/gcs_json/gcs_json_db.py +1335 -0
- agno/db/gcs_json/utils.py +228 -0
- agno/db/in_memory/__init__.py +3 -0
- agno/db/in_memory/in_memory_db.py +1160 -0
- agno/db/in_memory/utils.py +230 -0
- agno/db/json/__init__.py +3 -0
- agno/db/json/json_db.py +1328 -0
- agno/db/json/utils.py +230 -0
- agno/db/migrations/__init__.py +0 -0
- agno/db/migrations/v1_to_v2.py +635 -0
- agno/db/mongo/__init__.py +17 -0
- agno/db/mongo/async_mongo.py +2026 -0
- agno/db/mongo/mongo.py +1982 -0
- agno/db/mongo/schemas.py +87 -0
- agno/db/mongo/utils.py +259 -0
- agno/db/mysql/__init__.py +3 -0
- agno/db/mysql/mysql.py +2308 -0
- agno/db/mysql/schemas.py +138 -0
- agno/db/mysql/utils.py +355 -0
- agno/db/postgres/__init__.py +4 -0
- agno/db/postgres/async_postgres.py +1927 -0
- agno/db/postgres/postgres.py +2260 -0
- agno/db/postgres/schemas.py +139 -0
- agno/db/postgres/utils.py +442 -0
- agno/db/redis/__init__.py +3 -0
- agno/db/redis/redis.py +1660 -0
- agno/db/redis/schemas.py +123 -0
- agno/db/redis/utils.py +346 -0
- agno/db/schemas/__init__.py +4 -0
- agno/db/schemas/culture.py +120 -0
- agno/db/schemas/evals.py +33 -0
- agno/db/schemas/knowledge.py +40 -0
- agno/db/schemas/memory.py +46 -0
- agno/db/schemas/metrics.py +0 -0
- agno/db/singlestore/__init__.py +3 -0
- agno/db/singlestore/schemas.py +130 -0
- agno/db/singlestore/singlestore.py +2272 -0
- agno/db/singlestore/utils.py +384 -0
- agno/db/sqlite/__init__.py +4 -0
- agno/db/sqlite/async_sqlite.py +2293 -0
- agno/db/sqlite/schemas.py +133 -0
- agno/db/sqlite/sqlite.py +2288 -0
- agno/db/sqlite/utils.py +431 -0
- agno/db/surrealdb/__init__.py +3 -0
- agno/db/surrealdb/metrics.py +292 -0
- agno/db/surrealdb/models.py +309 -0
- agno/db/surrealdb/queries.py +71 -0
- agno/db/surrealdb/surrealdb.py +1353 -0
- agno/db/surrealdb/utils.py +147 -0
- agno/db/utils.py +116 -0
- agno/debug.py +18 -0
- agno/eval/__init__.py +14 -0
- agno/eval/accuracy.py +834 -0
- agno/eval/performance.py +773 -0
- agno/eval/reliability.py +306 -0
- agno/eval/utils.py +119 -0
- agno/exceptions.py +161 -0
- agno/filters.py +354 -0
- agno/guardrails/__init__.py +6 -0
- agno/guardrails/base.py +19 -0
- agno/guardrails/openai.py +144 -0
- agno/guardrails/pii.py +94 -0
- agno/guardrails/prompt_injection.py +52 -0
- agno/integrations/__init__.py +0 -0
- agno/integrations/discord/__init__.py +3 -0
- agno/integrations/discord/client.py +203 -0
- agno/knowledge/__init__.py +5 -0
- agno/knowledge/chunking/__init__.py +0 -0
- agno/knowledge/chunking/agentic.py +79 -0
- agno/knowledge/chunking/document.py +91 -0
- agno/knowledge/chunking/fixed.py +57 -0
- agno/knowledge/chunking/markdown.py +151 -0
- agno/knowledge/chunking/recursive.py +63 -0
- agno/knowledge/chunking/row.py +39 -0
- agno/knowledge/chunking/semantic.py +86 -0
- agno/knowledge/chunking/strategy.py +165 -0
- agno/knowledge/content.py +74 -0
- agno/knowledge/document/__init__.py +5 -0
- agno/knowledge/document/base.py +58 -0
- agno/knowledge/embedder/__init__.py +5 -0
- agno/knowledge/embedder/aws_bedrock.py +343 -0
- agno/knowledge/embedder/azure_openai.py +210 -0
- agno/knowledge/embedder/base.py +23 -0
- agno/knowledge/embedder/cohere.py +323 -0
- agno/knowledge/embedder/fastembed.py +62 -0
- agno/knowledge/embedder/fireworks.py +13 -0
- agno/knowledge/embedder/google.py +258 -0
- agno/knowledge/embedder/huggingface.py +94 -0
- agno/knowledge/embedder/jina.py +182 -0
- agno/knowledge/embedder/langdb.py +22 -0
- agno/knowledge/embedder/mistral.py +206 -0
- agno/knowledge/embedder/nebius.py +13 -0
- agno/knowledge/embedder/ollama.py +154 -0
- agno/knowledge/embedder/openai.py +195 -0
- agno/knowledge/embedder/sentence_transformer.py +63 -0
- agno/knowledge/embedder/together.py +13 -0
- agno/knowledge/embedder/vllm.py +262 -0
- agno/knowledge/embedder/voyageai.py +165 -0
- agno/knowledge/knowledge.py +1988 -0
- agno/knowledge/reader/__init__.py +7 -0
- agno/knowledge/reader/arxiv_reader.py +81 -0
- agno/knowledge/reader/base.py +95 -0
- agno/knowledge/reader/csv_reader.py +166 -0
- agno/knowledge/reader/docx_reader.py +82 -0
- agno/knowledge/reader/field_labeled_csv_reader.py +292 -0
- agno/knowledge/reader/firecrawl_reader.py +201 -0
- agno/knowledge/reader/json_reader.py +87 -0
- agno/knowledge/reader/markdown_reader.py +137 -0
- agno/knowledge/reader/pdf_reader.py +431 -0
- agno/knowledge/reader/pptx_reader.py +101 -0
- agno/knowledge/reader/reader_factory.py +313 -0
- agno/knowledge/reader/s3_reader.py +89 -0
- agno/knowledge/reader/tavily_reader.py +194 -0
- agno/knowledge/reader/text_reader.py +115 -0
- agno/knowledge/reader/web_search_reader.py +372 -0
- agno/knowledge/reader/website_reader.py +455 -0
- agno/knowledge/reader/wikipedia_reader.py +59 -0
- agno/knowledge/reader/youtube_reader.py +78 -0
- agno/knowledge/remote_content/__init__.py +0 -0
- agno/knowledge/remote_content/remote_content.py +88 -0
- agno/knowledge/reranker/__init__.py +3 -0
- agno/knowledge/reranker/base.py +14 -0
- agno/knowledge/reranker/cohere.py +64 -0
- agno/knowledge/reranker/infinity.py +195 -0
- agno/knowledge/reranker/sentence_transformer.py +54 -0
- agno/knowledge/types.py +39 -0
- agno/knowledge/utils.py +189 -0
- agno/media.py +462 -0
- agno/memory/__init__.py +3 -0
- agno/memory/manager.py +1327 -0
- agno/models/__init__.py +0 -0
- agno/models/aimlapi/__init__.py +5 -0
- agno/models/aimlapi/aimlapi.py +45 -0
- agno/models/anthropic/__init__.py +5 -0
- agno/models/anthropic/claude.py +757 -0
- agno/models/aws/__init__.py +15 -0
- agno/models/aws/bedrock.py +701 -0
- agno/models/aws/claude.py +378 -0
- agno/models/azure/__init__.py +18 -0
- agno/models/azure/ai_foundry.py +485 -0
- agno/models/azure/openai_chat.py +131 -0
- agno/models/base.py +2175 -0
- agno/models/cerebras/__init__.py +12 -0
- agno/models/cerebras/cerebras.py +501 -0
- agno/models/cerebras/cerebras_openai.py +112 -0
- agno/models/cohere/__init__.py +5 -0
- agno/models/cohere/chat.py +389 -0
- agno/models/cometapi/__init__.py +5 -0
- agno/models/cometapi/cometapi.py +57 -0
- agno/models/dashscope/__init__.py +5 -0
- agno/models/dashscope/dashscope.py +91 -0
- agno/models/deepinfra/__init__.py +5 -0
- agno/models/deepinfra/deepinfra.py +28 -0
- agno/models/deepseek/__init__.py +5 -0
- agno/models/deepseek/deepseek.py +61 -0
- agno/models/defaults.py +1 -0
- agno/models/fireworks/__init__.py +5 -0
- agno/models/fireworks/fireworks.py +26 -0
- agno/models/google/__init__.py +5 -0
- agno/models/google/gemini.py +1085 -0
- agno/models/groq/__init__.py +5 -0
- agno/models/groq/groq.py +556 -0
- agno/models/huggingface/__init__.py +5 -0
- agno/models/huggingface/huggingface.py +491 -0
- agno/models/ibm/__init__.py +5 -0
- agno/models/ibm/watsonx.py +422 -0
- agno/models/internlm/__init__.py +3 -0
- agno/models/internlm/internlm.py +26 -0
- agno/models/langdb/__init__.py +1 -0
- agno/models/langdb/langdb.py +48 -0
- agno/models/litellm/__init__.py +14 -0
- agno/models/litellm/chat.py +468 -0
- agno/models/litellm/litellm_openai.py +25 -0
- agno/models/llama_cpp/__init__.py +5 -0
- agno/models/llama_cpp/llama_cpp.py +22 -0
- agno/models/lmstudio/__init__.py +5 -0
- agno/models/lmstudio/lmstudio.py +25 -0
- agno/models/message.py +434 -0
- agno/models/meta/__init__.py +12 -0
- agno/models/meta/llama.py +475 -0
- agno/models/meta/llama_openai.py +78 -0
- agno/models/metrics.py +120 -0
- agno/models/mistral/__init__.py +5 -0
- agno/models/mistral/mistral.py +432 -0
- agno/models/nebius/__init__.py +3 -0
- agno/models/nebius/nebius.py +54 -0
- agno/models/nexus/__init__.py +3 -0
- agno/models/nexus/nexus.py +22 -0
- agno/models/nvidia/__init__.py +5 -0
- agno/models/nvidia/nvidia.py +28 -0
- agno/models/ollama/__init__.py +5 -0
- agno/models/ollama/chat.py +441 -0
- agno/models/openai/__init__.py +9 -0
- agno/models/openai/chat.py +883 -0
- agno/models/openai/like.py +27 -0
- agno/models/openai/responses.py +1050 -0
- agno/models/openrouter/__init__.py +5 -0
- agno/models/openrouter/openrouter.py +66 -0
- agno/models/perplexity/__init__.py +5 -0
- agno/models/perplexity/perplexity.py +187 -0
- agno/models/portkey/__init__.py +3 -0
- agno/models/portkey/portkey.py +81 -0
- agno/models/requesty/__init__.py +5 -0
- agno/models/requesty/requesty.py +52 -0
- agno/models/response.py +199 -0
- agno/models/sambanova/__init__.py +5 -0
- agno/models/sambanova/sambanova.py +28 -0
- agno/models/siliconflow/__init__.py +5 -0
- agno/models/siliconflow/siliconflow.py +25 -0
- agno/models/together/__init__.py +5 -0
- agno/models/together/together.py +25 -0
- agno/models/utils.py +266 -0
- agno/models/vercel/__init__.py +3 -0
- agno/models/vercel/v0.py +26 -0
- agno/models/vertexai/__init__.py +0 -0
- agno/models/vertexai/claude.py +70 -0
- agno/models/vllm/__init__.py +3 -0
- agno/models/vllm/vllm.py +78 -0
- agno/models/xai/__init__.py +3 -0
- agno/models/xai/xai.py +113 -0
- agno/os/__init__.py +3 -0
- agno/os/app.py +876 -0
- agno/os/auth.py +57 -0
- agno/os/config.py +104 -0
- agno/os/interfaces/__init__.py +1 -0
- agno/os/interfaces/a2a/__init__.py +3 -0
- agno/os/interfaces/a2a/a2a.py +42 -0
- agno/os/interfaces/a2a/router.py +250 -0
- agno/os/interfaces/a2a/utils.py +924 -0
- agno/os/interfaces/agui/__init__.py +3 -0
- agno/os/interfaces/agui/agui.py +47 -0
- agno/os/interfaces/agui/router.py +144 -0
- agno/os/interfaces/agui/utils.py +534 -0
- agno/os/interfaces/base.py +25 -0
- agno/os/interfaces/slack/__init__.py +3 -0
- agno/os/interfaces/slack/router.py +148 -0
- agno/os/interfaces/slack/security.py +30 -0
- agno/os/interfaces/slack/slack.py +47 -0
- agno/os/interfaces/whatsapp/__init__.py +3 -0
- agno/os/interfaces/whatsapp/router.py +211 -0
- agno/os/interfaces/whatsapp/security.py +53 -0
- agno/os/interfaces/whatsapp/whatsapp.py +36 -0
- agno/os/mcp.py +292 -0
- agno/os/middleware/__init__.py +7 -0
- agno/os/middleware/jwt.py +233 -0
- agno/os/router.py +1763 -0
- agno/os/routers/__init__.py +3 -0
- agno/os/routers/evals/__init__.py +3 -0
- agno/os/routers/evals/evals.py +430 -0
- agno/os/routers/evals/schemas.py +142 -0
- agno/os/routers/evals/utils.py +162 -0
- agno/os/routers/health.py +31 -0
- agno/os/routers/home.py +52 -0
- agno/os/routers/knowledge/__init__.py +3 -0
- agno/os/routers/knowledge/knowledge.py +997 -0
- agno/os/routers/knowledge/schemas.py +178 -0
- agno/os/routers/memory/__init__.py +3 -0
- agno/os/routers/memory/memory.py +515 -0
- agno/os/routers/memory/schemas.py +62 -0
- agno/os/routers/metrics/__init__.py +3 -0
- agno/os/routers/metrics/metrics.py +190 -0
- agno/os/routers/metrics/schemas.py +47 -0
- agno/os/routers/session/__init__.py +3 -0
- agno/os/routers/session/session.py +997 -0
- agno/os/schema.py +1055 -0
- agno/os/settings.py +43 -0
- agno/os/utils.py +630 -0
- agno/py.typed +0 -0
- agno/reasoning/__init__.py +0 -0
- agno/reasoning/anthropic.py +80 -0
- agno/reasoning/azure_ai_foundry.py +67 -0
- agno/reasoning/deepseek.py +63 -0
- agno/reasoning/default.py +97 -0
- agno/reasoning/gemini.py +73 -0
- agno/reasoning/groq.py +71 -0
- agno/reasoning/helpers.py +63 -0
- agno/reasoning/ollama.py +67 -0
- agno/reasoning/openai.py +86 -0
- agno/reasoning/step.py +31 -0
- agno/reasoning/vertexai.py +76 -0
- agno/run/__init__.py +6 -0
- agno/run/agent.py +787 -0
- agno/run/base.py +229 -0
- agno/run/cancel.py +81 -0
- agno/run/messages.py +32 -0
- agno/run/team.py +753 -0
- agno/run/workflow.py +708 -0
- agno/session/__init__.py +10 -0
- agno/session/agent.py +295 -0
- agno/session/summary.py +265 -0
- agno/session/team.py +392 -0
- agno/session/workflow.py +205 -0
- agno/team/__init__.py +37 -0
- agno/team/team.py +8793 -0
- agno/tools/__init__.py +10 -0
- agno/tools/agentql.py +120 -0
- agno/tools/airflow.py +69 -0
- agno/tools/api.py +122 -0
- agno/tools/apify.py +314 -0
- agno/tools/arxiv.py +127 -0
- agno/tools/aws_lambda.py +53 -0
- agno/tools/aws_ses.py +66 -0
- agno/tools/baidusearch.py +89 -0
- agno/tools/bitbucket.py +292 -0
- agno/tools/brandfetch.py +213 -0
- agno/tools/bravesearch.py +106 -0
- agno/tools/brightdata.py +367 -0
- agno/tools/browserbase.py +209 -0
- agno/tools/calcom.py +255 -0
- agno/tools/calculator.py +151 -0
- agno/tools/cartesia.py +187 -0
- agno/tools/clickup.py +244 -0
- agno/tools/confluence.py +240 -0
- agno/tools/crawl4ai.py +158 -0
- agno/tools/csv_toolkit.py +185 -0
- agno/tools/dalle.py +110 -0
- agno/tools/daytona.py +475 -0
- agno/tools/decorator.py +262 -0
- agno/tools/desi_vocal.py +108 -0
- agno/tools/discord.py +161 -0
- agno/tools/docker.py +716 -0
- agno/tools/duckdb.py +379 -0
- agno/tools/duckduckgo.py +91 -0
- agno/tools/e2b.py +703 -0
- agno/tools/eleven_labs.py +196 -0
- agno/tools/email.py +67 -0
- agno/tools/evm.py +129 -0
- agno/tools/exa.py +396 -0
- agno/tools/fal.py +127 -0
- agno/tools/file.py +240 -0
- agno/tools/file_generation.py +350 -0
- agno/tools/financial_datasets.py +288 -0
- agno/tools/firecrawl.py +143 -0
- agno/tools/function.py +1187 -0
- agno/tools/giphy.py +93 -0
- agno/tools/github.py +1760 -0
- agno/tools/gmail.py +922 -0
- agno/tools/google_bigquery.py +117 -0
- agno/tools/google_drive.py +270 -0
- agno/tools/google_maps.py +253 -0
- agno/tools/googlecalendar.py +674 -0
- agno/tools/googlesearch.py +98 -0
- agno/tools/googlesheets.py +377 -0
- agno/tools/hackernews.py +77 -0
- agno/tools/jina.py +101 -0
- agno/tools/jira.py +170 -0
- agno/tools/knowledge.py +218 -0
- agno/tools/linear.py +426 -0
- agno/tools/linkup.py +58 -0
- agno/tools/local_file_system.py +90 -0
- agno/tools/lumalab.py +183 -0
- agno/tools/mcp/__init__.py +10 -0
- agno/tools/mcp/mcp.py +331 -0
- agno/tools/mcp/multi_mcp.py +347 -0
- agno/tools/mcp/params.py +24 -0
- agno/tools/mcp_toolbox.py +284 -0
- agno/tools/mem0.py +193 -0
- agno/tools/memori.py +339 -0
- agno/tools/memory.py +419 -0
- agno/tools/mlx_transcribe.py +139 -0
- agno/tools/models/__init__.py +0 -0
- agno/tools/models/azure_openai.py +190 -0
- agno/tools/models/gemini.py +203 -0
- agno/tools/models/groq.py +158 -0
- agno/tools/models/morph.py +186 -0
- agno/tools/models/nebius.py +124 -0
- agno/tools/models_labs.py +195 -0
- agno/tools/moviepy_video.py +349 -0
- agno/tools/neo4j.py +134 -0
- agno/tools/newspaper.py +46 -0
- agno/tools/newspaper4k.py +93 -0
- agno/tools/notion.py +204 -0
- agno/tools/openai.py +202 -0
- agno/tools/openbb.py +160 -0
- agno/tools/opencv.py +321 -0
- agno/tools/openweather.py +233 -0
- agno/tools/oxylabs.py +385 -0
- agno/tools/pandas.py +102 -0
- agno/tools/parallel.py +314 -0
- agno/tools/postgres.py +257 -0
- agno/tools/pubmed.py +188 -0
- agno/tools/python.py +205 -0
- agno/tools/reasoning.py +283 -0
- agno/tools/reddit.py +467 -0
- agno/tools/replicate.py +117 -0
- agno/tools/resend.py +62 -0
- agno/tools/scrapegraph.py +222 -0
- agno/tools/searxng.py +152 -0
- agno/tools/serpapi.py +116 -0
- agno/tools/serper.py +255 -0
- agno/tools/shell.py +53 -0
- agno/tools/slack.py +136 -0
- agno/tools/sleep.py +20 -0
- agno/tools/spider.py +116 -0
- agno/tools/sql.py +154 -0
- agno/tools/streamlit/__init__.py +0 -0
- agno/tools/streamlit/components.py +113 -0
- agno/tools/tavily.py +254 -0
- agno/tools/telegram.py +48 -0
- agno/tools/todoist.py +218 -0
- agno/tools/tool_registry.py +1 -0
- agno/tools/toolkit.py +146 -0
- agno/tools/trafilatura.py +388 -0
- agno/tools/trello.py +274 -0
- agno/tools/twilio.py +186 -0
- agno/tools/user_control_flow.py +78 -0
- agno/tools/valyu.py +228 -0
- agno/tools/visualization.py +467 -0
- agno/tools/webbrowser.py +28 -0
- agno/tools/webex.py +76 -0
- agno/tools/website.py +54 -0
- agno/tools/webtools.py +45 -0
- agno/tools/whatsapp.py +286 -0
- agno/tools/wikipedia.py +63 -0
- agno/tools/workflow.py +278 -0
- agno/tools/x.py +335 -0
- agno/tools/yfinance.py +257 -0
- agno/tools/youtube.py +184 -0
- agno/tools/zendesk.py +82 -0
- agno/tools/zep.py +454 -0
- agno/tools/zoom.py +382 -0
- agno/utils/__init__.py +0 -0
- agno/utils/agent.py +820 -0
- agno/utils/audio.py +49 -0
- agno/utils/certs.py +27 -0
- agno/utils/code_execution.py +11 -0
- agno/utils/common.py +132 -0
- agno/utils/dttm.py +13 -0
- agno/utils/enum.py +22 -0
- agno/utils/env.py +11 -0
- agno/utils/events.py +696 -0
- agno/utils/format_str.py +16 -0
- agno/utils/functions.py +166 -0
- agno/utils/gemini.py +426 -0
- agno/utils/hooks.py +57 -0
- agno/utils/http.py +74 -0
- agno/utils/json_schema.py +234 -0
- agno/utils/knowledge.py +36 -0
- agno/utils/location.py +19 -0
- agno/utils/log.py +255 -0
- agno/utils/mcp.py +214 -0
- agno/utils/media.py +352 -0
- agno/utils/merge_dict.py +41 -0
- agno/utils/message.py +118 -0
- agno/utils/models/__init__.py +0 -0
- agno/utils/models/ai_foundry.py +43 -0
- agno/utils/models/claude.py +358 -0
- agno/utils/models/cohere.py +87 -0
- agno/utils/models/llama.py +78 -0
- agno/utils/models/mistral.py +98 -0
- agno/utils/models/openai_responses.py +140 -0
- agno/utils/models/schema_utils.py +153 -0
- agno/utils/models/watsonx.py +41 -0
- agno/utils/openai.py +257 -0
- agno/utils/pickle.py +32 -0
- agno/utils/pprint.py +178 -0
- agno/utils/print_response/__init__.py +0 -0
- agno/utils/print_response/agent.py +842 -0
- agno/utils/print_response/team.py +1724 -0
- agno/utils/print_response/workflow.py +1668 -0
- agno/utils/prompts.py +111 -0
- agno/utils/reasoning.py +108 -0
- agno/utils/response.py +163 -0
- agno/utils/response_iterator.py +17 -0
- agno/utils/safe_formatter.py +24 -0
- agno/utils/serialize.py +32 -0
- agno/utils/shell.py +22 -0
- agno/utils/streamlit.py +487 -0
- agno/utils/string.py +231 -0
- agno/utils/team.py +139 -0
- agno/utils/timer.py +41 -0
- agno/utils/tools.py +102 -0
- agno/utils/web.py +23 -0
- agno/utils/whatsapp.py +305 -0
- agno/utils/yaml_io.py +25 -0
- agno/vectordb/__init__.py +3 -0
- agno/vectordb/base.py +127 -0
- agno/vectordb/cassandra/__init__.py +5 -0
- agno/vectordb/cassandra/cassandra.py +501 -0
- agno/vectordb/cassandra/extra_param_mixin.py +11 -0
- agno/vectordb/cassandra/index.py +13 -0
- agno/vectordb/chroma/__init__.py +5 -0
- agno/vectordb/chroma/chromadb.py +929 -0
- agno/vectordb/clickhouse/__init__.py +9 -0
- agno/vectordb/clickhouse/clickhousedb.py +835 -0
- agno/vectordb/clickhouse/index.py +9 -0
- agno/vectordb/couchbase/__init__.py +3 -0
- agno/vectordb/couchbase/couchbase.py +1442 -0
- agno/vectordb/distance.py +7 -0
- agno/vectordb/lancedb/__init__.py +6 -0
- agno/vectordb/lancedb/lance_db.py +995 -0
- agno/vectordb/langchaindb/__init__.py +5 -0
- agno/vectordb/langchaindb/langchaindb.py +163 -0
- agno/vectordb/lightrag/__init__.py +5 -0
- agno/vectordb/lightrag/lightrag.py +388 -0
- agno/vectordb/llamaindex/__init__.py +3 -0
- agno/vectordb/llamaindex/llamaindexdb.py +166 -0
- agno/vectordb/milvus/__init__.py +4 -0
- agno/vectordb/milvus/milvus.py +1182 -0
- agno/vectordb/mongodb/__init__.py +9 -0
- agno/vectordb/mongodb/mongodb.py +1417 -0
- agno/vectordb/pgvector/__init__.py +12 -0
- agno/vectordb/pgvector/index.py +23 -0
- agno/vectordb/pgvector/pgvector.py +1462 -0
- agno/vectordb/pineconedb/__init__.py +5 -0
- agno/vectordb/pineconedb/pineconedb.py +747 -0
- agno/vectordb/qdrant/__init__.py +5 -0
- agno/vectordb/qdrant/qdrant.py +1134 -0
- agno/vectordb/redis/__init__.py +9 -0
- agno/vectordb/redis/redisdb.py +694 -0
- agno/vectordb/search.py +7 -0
- agno/vectordb/singlestore/__init__.py +10 -0
- agno/vectordb/singlestore/index.py +41 -0
- agno/vectordb/singlestore/singlestore.py +763 -0
- agno/vectordb/surrealdb/__init__.py +3 -0
- agno/vectordb/surrealdb/surrealdb.py +699 -0
- agno/vectordb/upstashdb/__init__.py +5 -0
- agno/vectordb/upstashdb/upstashdb.py +718 -0
- agno/vectordb/weaviate/__init__.py +8 -0
- agno/vectordb/weaviate/index.py +15 -0
- agno/vectordb/weaviate/weaviate.py +1005 -0
- agno/workflow/__init__.py +23 -0
- agno/workflow/agent.py +299 -0
- agno/workflow/condition.py +738 -0
- agno/workflow/loop.py +735 -0
- agno/workflow/parallel.py +824 -0
- agno/workflow/router.py +702 -0
- agno/workflow/step.py +1432 -0
- agno/workflow/steps.py +592 -0
- agno/workflow/types.py +520 -0
- agno/workflow/workflow.py +4321 -0
- agno-2.2.13.dist-info/METADATA +614 -0
- agno-2.2.13.dist-info/RECORD +575 -0
- agno-2.2.13.dist-info/WHEEL +5 -0
- agno-2.2.13.dist-info/licenses/LICENSE +201 -0
- agno-2.2.13.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1462 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from hashlib import md5
|
|
3
|
+
from math import sqrt
|
|
4
|
+
from typing import Any, Dict, List, Optional, Union, cast
|
|
5
|
+
|
|
6
|
+
from agno.utils.string import generate_id
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
from sqlalchemy import and_, not_, or_, update
|
|
10
|
+
from sqlalchemy.dialects import postgresql
|
|
11
|
+
from sqlalchemy.engine import Engine, create_engine
|
|
12
|
+
from sqlalchemy.inspection import inspect
|
|
13
|
+
from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
|
14
|
+
from sqlalchemy.schema import Column, Index, MetaData, Table
|
|
15
|
+
from sqlalchemy.sql.elements import ColumnElement
|
|
16
|
+
from sqlalchemy.sql.expression import bindparam, desc, func, select, text
|
|
17
|
+
from sqlalchemy.types import DateTime, Integer, String
|
|
18
|
+
|
|
19
|
+
except ImportError:
|
|
20
|
+
raise ImportError("`sqlalchemy` not installed. Please install using `pip install sqlalchemy psycopg`")
|
|
21
|
+
|
|
22
|
+
try:
|
|
23
|
+
from pgvector.sqlalchemy import Vector
|
|
24
|
+
except ImportError:
|
|
25
|
+
raise ImportError("`pgvector` not installed. Please install using `pip install pgvector`")
|
|
26
|
+
|
|
27
|
+
from agno.filters import FilterExpr
|
|
28
|
+
from agno.knowledge.document import Document
|
|
29
|
+
from agno.knowledge.embedder import Embedder
|
|
30
|
+
from agno.knowledge.reranker.base import Reranker
|
|
31
|
+
from agno.utils.log import log_debug, log_info, logger
|
|
32
|
+
from agno.vectordb.base import VectorDb
|
|
33
|
+
from agno.vectordb.distance import Distance
|
|
34
|
+
from agno.vectordb.pgvector.index import HNSW, Ivfflat
|
|
35
|
+
from agno.vectordb.search import SearchType
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class PgVector(VectorDb):
|
|
39
|
+
"""
|
|
40
|
+
PgVector class for managing vector operations with PostgreSQL and pgvector.
|
|
41
|
+
|
|
42
|
+
This class provides methods for creating, inserting, searching, and managing
|
|
43
|
+
vector data in a PostgreSQL database using the pgvector extension.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
def __init__(
|
|
47
|
+
self,
|
|
48
|
+
table_name: str,
|
|
49
|
+
schema: str = "ai",
|
|
50
|
+
name: Optional[str] = None,
|
|
51
|
+
description: Optional[str] = None,
|
|
52
|
+
id: Optional[str] = None,
|
|
53
|
+
db_url: Optional[str] = None,
|
|
54
|
+
db_engine: Optional[Engine] = None,
|
|
55
|
+
embedder: Optional[Embedder] = None,
|
|
56
|
+
search_type: SearchType = SearchType.vector,
|
|
57
|
+
vector_index: Union[Ivfflat, HNSW] = HNSW(),
|
|
58
|
+
distance: Distance = Distance.cosine,
|
|
59
|
+
prefix_match: bool = False,
|
|
60
|
+
vector_score_weight: float = 0.5,
|
|
61
|
+
content_language: str = "english",
|
|
62
|
+
schema_version: int = 1,
|
|
63
|
+
auto_upgrade_schema: bool = False,
|
|
64
|
+
reranker: Optional[Reranker] = None,
|
|
65
|
+
):
|
|
66
|
+
"""
|
|
67
|
+
Initialize the PgVector instance.
|
|
68
|
+
|
|
69
|
+
Args:
|
|
70
|
+
table_name (str): Name of the table to store vector data.
|
|
71
|
+
schema (str): Database schema name.
|
|
72
|
+
name (Optional[str]): Name of the vector database.
|
|
73
|
+
description (Optional[str]): Description of the vector database.
|
|
74
|
+
db_url (Optional[str]): Database connection URL.
|
|
75
|
+
db_engine (Optional[Engine]): SQLAlchemy database engine.
|
|
76
|
+
embedder (Optional[Embedder]): Embedder instance for creating embeddings.
|
|
77
|
+
search_type (SearchType): Type of search to perform.
|
|
78
|
+
vector_index (Union[Ivfflat, HNSW]): Vector index configuration.
|
|
79
|
+
distance (Distance): Distance metric for vector comparisons.
|
|
80
|
+
prefix_match (bool): Enable prefix matching for full-text search.
|
|
81
|
+
vector_score_weight (float): Weight for vector similarity in hybrid search.
|
|
82
|
+
content_language (str): Language for full-text search.
|
|
83
|
+
schema_version (int): Version of the database schema.
|
|
84
|
+
auto_upgrade_schema (bool): Automatically upgrade schema if True.
|
|
85
|
+
"""
|
|
86
|
+
if not table_name:
|
|
87
|
+
raise ValueError("Table name must be provided.")
|
|
88
|
+
|
|
89
|
+
if db_engine is None and db_url is None:
|
|
90
|
+
raise ValueError("Either 'db_url' or 'db_engine' must be provided.")
|
|
91
|
+
|
|
92
|
+
if id is None:
|
|
93
|
+
base_seed = db_url or str(db_engine.url) # type: ignore
|
|
94
|
+
schema_suffix = table_name if table_name is not None else "ai"
|
|
95
|
+
seed = f"{base_seed}#{schema_suffix}"
|
|
96
|
+
id = generate_id(seed)
|
|
97
|
+
|
|
98
|
+
# Initialize base class with name and description
|
|
99
|
+
super().__init__(id=id, name=name, description=description)
|
|
100
|
+
|
|
101
|
+
if db_engine is None:
|
|
102
|
+
if db_url is None:
|
|
103
|
+
raise ValueError("Must provide 'db_url' if 'db_engine' is None.")
|
|
104
|
+
try:
|
|
105
|
+
db_engine = create_engine(db_url)
|
|
106
|
+
except Exception as e:
|
|
107
|
+
logger.error(f"Failed to create engine from 'db_url': {e}")
|
|
108
|
+
raise
|
|
109
|
+
|
|
110
|
+
# Database settings
|
|
111
|
+
self.table_name: str = table_name
|
|
112
|
+
self.schema: str = schema
|
|
113
|
+
self.db_url: Optional[str] = db_url
|
|
114
|
+
self.db_engine: Engine = db_engine
|
|
115
|
+
self.metadata: MetaData = MetaData(schema=self.schema)
|
|
116
|
+
|
|
117
|
+
# Embedder for embedding the document contents
|
|
118
|
+
if embedder is None:
|
|
119
|
+
from agno.knowledge.embedder.openai import OpenAIEmbedder
|
|
120
|
+
|
|
121
|
+
embedder = OpenAIEmbedder()
|
|
122
|
+
log_info("Embedder not provided, using OpenAIEmbedder as default.")
|
|
123
|
+
self.embedder: Embedder = embedder
|
|
124
|
+
self.dimensions: Optional[int] = self.embedder.dimensions
|
|
125
|
+
|
|
126
|
+
if self.dimensions is None:
|
|
127
|
+
raise ValueError("Embedder.dimensions must be set.")
|
|
128
|
+
|
|
129
|
+
# Search type
|
|
130
|
+
self.search_type: SearchType = search_type
|
|
131
|
+
# Distance metric
|
|
132
|
+
self.distance: Distance = distance
|
|
133
|
+
# Index for the table
|
|
134
|
+
self.vector_index: Union[Ivfflat, HNSW] = vector_index
|
|
135
|
+
# Enable prefix matching for full-text search
|
|
136
|
+
self.prefix_match: bool = prefix_match
|
|
137
|
+
# Weight for the vector similarity score in hybrid search
|
|
138
|
+
self.vector_score_weight: float = vector_score_weight
|
|
139
|
+
# Content language for full-text search
|
|
140
|
+
self.content_language: str = content_language
|
|
141
|
+
|
|
142
|
+
# Table schema version
|
|
143
|
+
self.schema_version: int = schema_version
|
|
144
|
+
# Automatically upgrade schema if True
|
|
145
|
+
self.auto_upgrade_schema: bool = auto_upgrade_schema
|
|
146
|
+
|
|
147
|
+
# Reranker instance
|
|
148
|
+
self.reranker: Optional[Reranker] = reranker
|
|
149
|
+
|
|
150
|
+
# Database session
|
|
151
|
+
self.Session: scoped_session = scoped_session(sessionmaker(bind=self.db_engine))
|
|
152
|
+
# Database table
|
|
153
|
+
self.table: Table = self.get_table()
|
|
154
|
+
log_debug(f"Initialized PgVector with table '{self.schema}.{self.table_name}'")
|
|
155
|
+
|
|
156
|
+
def get_table_v1(self) -> Table:
|
|
157
|
+
"""
|
|
158
|
+
Get the SQLAlchemy Table object for schema version 1.
|
|
159
|
+
|
|
160
|
+
Returns:
|
|
161
|
+
Table: SQLAlchemy Table object representing the database table.
|
|
162
|
+
"""
|
|
163
|
+
if self.dimensions is None:
|
|
164
|
+
raise ValueError("Embedder dimensions are not set.")
|
|
165
|
+
table = Table(
|
|
166
|
+
self.table_name,
|
|
167
|
+
self.metadata,
|
|
168
|
+
Column("id", String, primary_key=True),
|
|
169
|
+
Column("name", String),
|
|
170
|
+
Column("meta_data", postgresql.JSONB, server_default=text("'{}'::jsonb")),
|
|
171
|
+
Column("filters", postgresql.JSONB, server_default=text("'{}'::jsonb"), nullable=True),
|
|
172
|
+
Column("content", postgresql.TEXT),
|
|
173
|
+
Column("embedding", Vector(self.dimensions)),
|
|
174
|
+
Column("usage", postgresql.JSONB),
|
|
175
|
+
Column("created_at", DateTime(timezone=True), server_default=func.now()),
|
|
176
|
+
Column("updated_at", DateTime(timezone=True), onupdate=func.now()),
|
|
177
|
+
Column("content_hash", String),
|
|
178
|
+
Column("content_id", String),
|
|
179
|
+
extend_existing=True,
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
# Add indexes
|
|
183
|
+
Index(f"idx_{self.table_name}_id", table.c.id)
|
|
184
|
+
Index(f"idx_{self.table_name}_name", table.c.name)
|
|
185
|
+
Index(f"idx_{self.table_name}_content_hash", table.c.content_hash)
|
|
186
|
+
Index(f"idx_{self.table_name}_content_id", table.c.content_id)
|
|
187
|
+
return table
|
|
188
|
+
|
|
189
|
+
def get_table(self) -> Table:
|
|
190
|
+
"""
|
|
191
|
+
Get the SQLAlchemy Table object based on the current schema version.
|
|
192
|
+
|
|
193
|
+
Returns:
|
|
194
|
+
Table: SQLAlchemy Table object representing the database table.
|
|
195
|
+
"""
|
|
196
|
+
if self.schema_version == 1:
|
|
197
|
+
return self.get_table_v1()
|
|
198
|
+
else:
|
|
199
|
+
raise NotImplementedError(f"Unsupported schema version: {self.schema_version}")
|
|
200
|
+
|
|
201
|
+
def table_exists(self) -> bool:
|
|
202
|
+
"""
|
|
203
|
+
Check if the table exists in the database.
|
|
204
|
+
|
|
205
|
+
Returns:
|
|
206
|
+
bool: True if the table exists, False otherwise.
|
|
207
|
+
"""
|
|
208
|
+
log_debug(f"Checking if table '{self.table.fullname}' exists.")
|
|
209
|
+
try:
|
|
210
|
+
return inspect(self.db_engine).has_table(self.table_name, schema=self.schema)
|
|
211
|
+
except Exception as e:
|
|
212
|
+
logger.error(f"Error checking if table exists: {e}")
|
|
213
|
+
return False
|
|
214
|
+
|
|
215
|
+
def create(self) -> None:
|
|
216
|
+
"""
|
|
217
|
+
Create the table if it does not exist.
|
|
218
|
+
"""
|
|
219
|
+
if not self.table_exists():
|
|
220
|
+
with self.Session() as sess, sess.begin():
|
|
221
|
+
log_debug("Creating extension: vector")
|
|
222
|
+
sess.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
|
223
|
+
if self.schema is not None:
|
|
224
|
+
log_debug(f"Creating schema: {self.schema}")
|
|
225
|
+
sess.execute(text(f"CREATE SCHEMA IF NOT EXISTS {self.schema};"))
|
|
226
|
+
log_debug(f"Creating table: {self.table_name}")
|
|
227
|
+
self.table.create(self.db_engine)
|
|
228
|
+
|
|
229
|
+
async def async_create(self) -> None:
|
|
230
|
+
"""Create the table asynchronously by running in a thread."""
|
|
231
|
+
await asyncio.to_thread(self.create)
|
|
232
|
+
|
|
233
|
+
def _record_exists(self, column, value) -> bool:
|
|
234
|
+
"""
|
|
235
|
+
Check if a record with the given column value exists in the table.
|
|
236
|
+
|
|
237
|
+
Args:
|
|
238
|
+
column: The column to check.
|
|
239
|
+
value: The value to search for.
|
|
240
|
+
|
|
241
|
+
Returns:
|
|
242
|
+
bool: True if the record exists, False otherwise.
|
|
243
|
+
"""
|
|
244
|
+
try:
|
|
245
|
+
with self.Session() as sess, sess.begin():
|
|
246
|
+
stmt = select(1).where(column == value).limit(1)
|
|
247
|
+
result = sess.execute(stmt).first()
|
|
248
|
+
return result is not None
|
|
249
|
+
except Exception as e:
|
|
250
|
+
logger.error(f"Error checking if record exists: {e}")
|
|
251
|
+
return False
|
|
252
|
+
|
|
253
|
+
def name_exists(self, name: str) -> bool:
|
|
254
|
+
"""
|
|
255
|
+
Check if a document with the given name exists in the table.
|
|
256
|
+
|
|
257
|
+
Args:
|
|
258
|
+
name (str): The name to check.
|
|
259
|
+
|
|
260
|
+
Returns:
|
|
261
|
+
bool: True if a document with the name exists, False otherwise.
|
|
262
|
+
"""
|
|
263
|
+
return self._record_exists(self.table.c.name, name)
|
|
264
|
+
|
|
265
|
+
async def async_name_exists(self, name: str) -> bool:
|
|
266
|
+
"""Check if name exists asynchronously by running in a thread."""
|
|
267
|
+
return await asyncio.to_thread(self.name_exists, name)
|
|
268
|
+
|
|
269
|
+
def id_exists(self, id: str) -> bool:
|
|
270
|
+
"""
|
|
271
|
+
Check if a document with the given ID exists in the table.
|
|
272
|
+
|
|
273
|
+
Args:
|
|
274
|
+
id (str): The ID to check.
|
|
275
|
+
|
|
276
|
+
Returns:
|
|
277
|
+
bool: True if a document with the ID exists, False otherwise.
|
|
278
|
+
"""
|
|
279
|
+
return self._record_exists(self.table.c.id, id)
|
|
280
|
+
|
|
281
|
+
def content_hash_exists(self, content_hash: str) -> bool:
|
|
282
|
+
"""
|
|
283
|
+
Check if a document with the given content hash exists in the table.
|
|
284
|
+
"""
|
|
285
|
+
return self._record_exists(self.table.c.content_hash, content_hash)
|
|
286
|
+
|
|
287
|
+
def _clean_content(self, content: str) -> str:
|
|
288
|
+
"""
|
|
289
|
+
Clean the content by replacing null characters.
|
|
290
|
+
|
|
291
|
+
Args:
|
|
292
|
+
content (str): The content to clean.
|
|
293
|
+
|
|
294
|
+
Returns:
|
|
295
|
+
str: The cleaned content.
|
|
296
|
+
"""
|
|
297
|
+
return content.replace("\x00", "\ufffd")
|
|
298
|
+
|
|
299
|
+
def insert(
|
|
300
|
+
self,
|
|
301
|
+
content_hash: str,
|
|
302
|
+
documents: List[Document],
|
|
303
|
+
filters: Optional[Dict[str, Any]] = None,
|
|
304
|
+
batch_size: int = 100,
|
|
305
|
+
) -> None:
|
|
306
|
+
"""
|
|
307
|
+
Insert documents into the database.
|
|
308
|
+
|
|
309
|
+
Args:
|
|
310
|
+
content_hash (str): The content hash to insert.
|
|
311
|
+
documents (List[Document]): List of documents to insert.
|
|
312
|
+
filters (Optional[Dict[str, Any]]): Filters to apply to the documents.
|
|
313
|
+
batch_size (int): Number of documents to insert in each batch.
|
|
314
|
+
"""
|
|
315
|
+
try:
|
|
316
|
+
with self.Session() as sess:
|
|
317
|
+
for i in range(0, len(documents), batch_size):
|
|
318
|
+
batch_docs = documents[i : i + batch_size]
|
|
319
|
+
log_debug(f"Processing batch starting at index {i}, size: {len(batch_docs)}")
|
|
320
|
+
try:
|
|
321
|
+
# Prepare documents for insertion
|
|
322
|
+
batch_records = []
|
|
323
|
+
for doc in batch_docs:
|
|
324
|
+
try:
|
|
325
|
+
batch_records.append(self._get_document_record(doc, filters, content_hash))
|
|
326
|
+
except Exception as e:
|
|
327
|
+
logger.error(f"Error processing document '{doc.name}': {e}")
|
|
328
|
+
|
|
329
|
+
# Insert the batch of records
|
|
330
|
+
insert_stmt = postgresql.insert(self.table)
|
|
331
|
+
sess.execute(insert_stmt, batch_records)
|
|
332
|
+
sess.commit() # Commit batch independently
|
|
333
|
+
log_info(f"Inserted batch of {len(batch_records)} documents.")
|
|
334
|
+
except Exception as e:
|
|
335
|
+
logger.error(f"Error with batch starting at index {i}: {e}")
|
|
336
|
+
sess.rollback() # Rollback the current batch if there's an error
|
|
337
|
+
raise
|
|
338
|
+
except Exception as e:
|
|
339
|
+
logger.error(f"Error inserting documents: {e}")
|
|
340
|
+
raise
|
|
341
|
+
|
|
342
|
+
async def async_insert(
|
|
343
|
+
self,
|
|
344
|
+
content_hash: str,
|
|
345
|
+
documents: List[Document],
|
|
346
|
+
filters: Optional[Dict[str, Any]] = None,
|
|
347
|
+
batch_size: int = 100,
|
|
348
|
+
) -> None:
|
|
349
|
+
"""Insert documents asynchronously with parallel embedding."""
|
|
350
|
+
try:
|
|
351
|
+
with self.Session() as sess:
|
|
352
|
+
for i in range(0, len(documents), batch_size):
|
|
353
|
+
batch_docs = documents[i : i + batch_size]
|
|
354
|
+
log_debug(f"Processing batch starting at index {i}, size: {len(batch_docs)}")
|
|
355
|
+
try:
|
|
356
|
+
# Embed all documents in the batch
|
|
357
|
+
await self._async_embed_documents(batch_docs)
|
|
358
|
+
|
|
359
|
+
# Prepare documents for insertion
|
|
360
|
+
batch_records = []
|
|
361
|
+
for doc in batch_docs:
|
|
362
|
+
try:
|
|
363
|
+
cleaned_content = self._clean_content(doc.content)
|
|
364
|
+
record_id = doc.id or content_hash
|
|
365
|
+
|
|
366
|
+
meta_data = doc.meta_data or {}
|
|
367
|
+
if filters:
|
|
368
|
+
meta_data.update(filters)
|
|
369
|
+
|
|
370
|
+
record = {
|
|
371
|
+
"id": record_id,
|
|
372
|
+
"name": doc.name,
|
|
373
|
+
"meta_data": doc.meta_data,
|
|
374
|
+
"filters": filters,
|
|
375
|
+
"content": cleaned_content,
|
|
376
|
+
"embedding": doc.embedding,
|
|
377
|
+
"usage": doc.usage,
|
|
378
|
+
"content_hash": content_hash,
|
|
379
|
+
"content_id": doc.content_id,
|
|
380
|
+
}
|
|
381
|
+
batch_records.append(record)
|
|
382
|
+
except Exception as e:
|
|
383
|
+
logger.error(f"Error processing document '{doc.name}': {e}")
|
|
384
|
+
|
|
385
|
+
# Insert the batch of records
|
|
386
|
+
if batch_records:
|
|
387
|
+
insert_stmt = postgresql.insert(self.table)
|
|
388
|
+
sess.execute(insert_stmt, batch_records)
|
|
389
|
+
sess.commit() # Commit batch independently
|
|
390
|
+
log_info(f"Inserted batch of {len(batch_records)} documents.")
|
|
391
|
+
except Exception as e:
|
|
392
|
+
logger.error(f"Error with batch starting at index {i}: {e}")
|
|
393
|
+
sess.rollback() # Rollback the current batch if there's an error
|
|
394
|
+
raise
|
|
395
|
+
except Exception as e:
|
|
396
|
+
logger.error(f"Error inserting documents: {e}")
|
|
397
|
+
raise
|
|
398
|
+
|
|
399
|
+
def upsert_available(self) -> bool:
|
|
400
|
+
"""
|
|
401
|
+
Check if upsert operation is available.
|
|
402
|
+
|
|
403
|
+
Returns:
|
|
404
|
+
bool: Always returns True for PgVector.
|
|
405
|
+
"""
|
|
406
|
+
return True
|
|
407
|
+
|
|
408
|
+
def upsert(
|
|
409
|
+
self,
|
|
410
|
+
content_hash: str,
|
|
411
|
+
documents: List[Document],
|
|
412
|
+
filters: Optional[Dict[str, Any]] = None,
|
|
413
|
+
batch_size: int = 100,
|
|
414
|
+
) -> None:
|
|
415
|
+
"""
|
|
416
|
+
Upsert documents by content hash.
|
|
417
|
+
First delete all documents with the same content hash.
|
|
418
|
+
Then upsert the new documents.
|
|
419
|
+
"""
|
|
420
|
+
try:
|
|
421
|
+
if self.content_hash_exists(content_hash):
|
|
422
|
+
self._delete_by_content_hash(content_hash)
|
|
423
|
+
self._upsert(content_hash, documents, filters, batch_size)
|
|
424
|
+
except Exception as e:
|
|
425
|
+
logger.error(f"Error upserting documents by content hash: {e}")
|
|
426
|
+
raise
|
|
427
|
+
|
|
428
|
+
def _upsert(
|
|
429
|
+
self,
|
|
430
|
+
content_hash: str,
|
|
431
|
+
documents: List[Document],
|
|
432
|
+
filters: Optional[Dict[str, Any]] = None,
|
|
433
|
+
batch_size: int = 100,
|
|
434
|
+
) -> None:
|
|
435
|
+
"""
|
|
436
|
+
Upsert (insert or update) documents in the database.
|
|
437
|
+
|
|
438
|
+
Args:
|
|
439
|
+
documents (List[Document]): List of documents to upsert.
|
|
440
|
+
filters (Optional[Dict[str, Any]]): Filters to apply to the documents.
|
|
441
|
+
batch_size (int): Number of documents to upsert in each batch.
|
|
442
|
+
"""
|
|
443
|
+
try:
|
|
444
|
+
with self.Session() as sess:
|
|
445
|
+
for i in range(0, len(documents), batch_size):
|
|
446
|
+
batch_docs = documents[i : i + batch_size]
|
|
447
|
+
log_info(f"Processing batch starting at index {i}, size: {len(batch_docs)}")
|
|
448
|
+
try:
|
|
449
|
+
# Prepare documents for upserting
|
|
450
|
+
batch_records_dict: Dict[str, Dict[str, Any]] = {} # Use dict to deduplicate by ID
|
|
451
|
+
for doc in batch_docs:
|
|
452
|
+
try:
|
|
453
|
+
batch_records_dict[doc.id] = self._get_document_record(doc, filters, content_hash) # type: ignore
|
|
454
|
+
except Exception as e:
|
|
455
|
+
logger.error(f"Error processing document '{doc.name}': {e}")
|
|
456
|
+
|
|
457
|
+
# Convert dict to list for upsert
|
|
458
|
+
batch_records = list(batch_records_dict.values())
|
|
459
|
+
if not batch_records:
|
|
460
|
+
log_info("No valid records to upsert in this batch.")
|
|
461
|
+
continue
|
|
462
|
+
|
|
463
|
+
# Upsert the batch of records
|
|
464
|
+
insert_stmt = postgresql.insert(self.table).values(batch_records)
|
|
465
|
+
upsert_stmt = insert_stmt.on_conflict_do_update(
|
|
466
|
+
index_elements=["id"],
|
|
467
|
+
set_={
|
|
468
|
+
"name": insert_stmt.excluded.name,
|
|
469
|
+
"meta_data": insert_stmt.excluded.meta_data,
|
|
470
|
+
"filters": insert_stmt.excluded.filters,
|
|
471
|
+
"content": insert_stmt.excluded.content,
|
|
472
|
+
"embedding": insert_stmt.excluded.embedding,
|
|
473
|
+
"usage": insert_stmt.excluded.usage,
|
|
474
|
+
"content_hash": insert_stmt.excluded.content_hash,
|
|
475
|
+
"content_id": insert_stmt.excluded.content_id,
|
|
476
|
+
},
|
|
477
|
+
)
|
|
478
|
+
sess.execute(upsert_stmt)
|
|
479
|
+
sess.commit() # Commit batch independently
|
|
480
|
+
log_info(f"Upserted batch of {len(batch_records)} documents.")
|
|
481
|
+
except Exception as e:
|
|
482
|
+
logger.error(f"Error with batch starting at index {i}: {e}")
|
|
483
|
+
sess.rollback() # Rollback the current batch if there's an error
|
|
484
|
+
raise
|
|
485
|
+
except Exception as e:
|
|
486
|
+
logger.error(f"Error upserting documents: {e}")
|
|
487
|
+
raise
|
|
488
|
+
|
|
489
|
+
def _get_document_record(
|
|
490
|
+
self, doc: Document, filters: Optional[Dict[str, Any]] = None, content_hash: str = ""
|
|
491
|
+
) -> Dict[str, Any]:
|
|
492
|
+
doc.embed(embedder=self.embedder)
|
|
493
|
+
cleaned_content = self._clean_content(doc.content)
|
|
494
|
+
record_id = doc.id or content_hash
|
|
495
|
+
|
|
496
|
+
meta_data = doc.meta_data or {}
|
|
497
|
+
if filters:
|
|
498
|
+
meta_data.update(filters)
|
|
499
|
+
|
|
500
|
+
return {
|
|
501
|
+
"id": record_id,
|
|
502
|
+
"name": doc.name,
|
|
503
|
+
"meta_data": doc.meta_data,
|
|
504
|
+
"filters": filters,
|
|
505
|
+
"content": cleaned_content,
|
|
506
|
+
"embedding": doc.embedding,
|
|
507
|
+
"usage": doc.usage,
|
|
508
|
+
"content_hash": content_hash,
|
|
509
|
+
"content_id": doc.content_id,
|
|
510
|
+
}
|
|
511
|
+
|
|
512
|
+
async def _async_embed_documents(self, batch_docs: List[Document]) -> None:
|
|
513
|
+
"""
|
|
514
|
+
Embed a batch of documents using either batch embedding or individual embedding.
|
|
515
|
+
|
|
516
|
+
Args:
|
|
517
|
+
batch_docs: List of documents to embed
|
|
518
|
+
"""
|
|
519
|
+
if self.embedder.enable_batch and hasattr(self.embedder, "async_get_embeddings_batch_and_usage"):
|
|
520
|
+
# Use batch embedding when enabled and supported
|
|
521
|
+
try:
|
|
522
|
+
# Extract content from all documents
|
|
523
|
+
doc_contents = [doc.content for doc in batch_docs]
|
|
524
|
+
|
|
525
|
+
# Get batch embeddings and usage
|
|
526
|
+
embeddings, usages = await self.embedder.async_get_embeddings_batch_and_usage(doc_contents)
|
|
527
|
+
|
|
528
|
+
# Process documents with pre-computed embeddings
|
|
529
|
+
for j, doc in enumerate(batch_docs):
|
|
530
|
+
try:
|
|
531
|
+
if j < len(embeddings):
|
|
532
|
+
doc.embedding = embeddings[j]
|
|
533
|
+
doc.usage = usages[j] if j < len(usages) else None
|
|
534
|
+
except Exception as e:
|
|
535
|
+
logger.error(f"Error assigning batch embedding to document '{doc.name}': {e}")
|
|
536
|
+
|
|
537
|
+
except Exception as e:
|
|
538
|
+
# Check if this is a rate limit error - don't fall back as it would make things worse
|
|
539
|
+
error_str = str(e).lower()
|
|
540
|
+
is_rate_limit = any(
|
|
541
|
+
phrase in error_str
|
|
542
|
+
for phrase in ["rate limit", "too many requests", "429", "trial key", "api calls / minute"]
|
|
543
|
+
)
|
|
544
|
+
|
|
545
|
+
if is_rate_limit:
|
|
546
|
+
logger.error(f"Rate limit detected during batch embedding. {e}")
|
|
547
|
+
raise e
|
|
548
|
+
else:
|
|
549
|
+
logger.warning(f"Async batch embedding failed, falling back to individual embeddings: {e}")
|
|
550
|
+
# Fall back to individual embedding
|
|
551
|
+
embed_tasks = [doc.async_embed(embedder=self.embedder) for doc in batch_docs]
|
|
552
|
+
await asyncio.gather(*embed_tasks, return_exceptions=True)
|
|
553
|
+
else:
|
|
554
|
+
# Use individual embedding
|
|
555
|
+
embed_tasks = [doc.async_embed(embedder=self.embedder) for doc in batch_docs]
|
|
556
|
+
await asyncio.gather(*embed_tasks, return_exceptions=True)
|
|
557
|
+
|
|
558
|
+
async def async_upsert(
|
|
559
|
+
self,
|
|
560
|
+
content_hash: str,
|
|
561
|
+
documents: List[Document],
|
|
562
|
+
filters: Optional[Dict[str, Any]] = None,
|
|
563
|
+
batch_size: int = 100,
|
|
564
|
+
) -> None:
|
|
565
|
+
"""Upsert documents asynchronously by running in a thread."""
|
|
566
|
+
try:
|
|
567
|
+
if self.content_hash_exists(content_hash):
|
|
568
|
+
self._delete_by_content_hash(content_hash)
|
|
569
|
+
await self._async_upsert(content_hash, documents, filters, batch_size)
|
|
570
|
+
except Exception as e:
|
|
571
|
+
logger.error(f"Error upserting documents by content hash: {e}")
|
|
572
|
+
raise
|
|
573
|
+
|
|
574
|
+
async def _async_upsert(
|
|
575
|
+
self,
|
|
576
|
+
content_hash: str,
|
|
577
|
+
documents: List[Document],
|
|
578
|
+
filters: Optional[Dict[str, Any]] = None,
|
|
579
|
+
batch_size: int = 100,
|
|
580
|
+
) -> None:
|
|
581
|
+
"""
|
|
582
|
+
Upsert (insert or update) documents in the database.
|
|
583
|
+
|
|
584
|
+
Args:
|
|
585
|
+
documents (List[Document]): List of documents to upsert.
|
|
586
|
+
filters (Optional[Dict[str, Any]]): Filters to apply to the documents.
|
|
587
|
+
batch_size (int): Number of documents to upsert in each batch.
|
|
588
|
+
"""
|
|
589
|
+
try:
|
|
590
|
+
with self.Session() as sess:
|
|
591
|
+
for i in range(0, len(documents), batch_size):
|
|
592
|
+
batch_docs = documents[i : i + batch_size]
|
|
593
|
+
log_info(f"Processing batch starting at index {i}, size: {len(batch_docs)}")
|
|
594
|
+
try:
|
|
595
|
+
# Embed all documents in the batch
|
|
596
|
+
await self._async_embed_documents(batch_docs)
|
|
597
|
+
|
|
598
|
+
# Prepare documents for upserting
|
|
599
|
+
batch_records_dict = {} # Use dict to deduplicate by ID
|
|
600
|
+
for doc in batch_docs:
|
|
601
|
+
try:
|
|
602
|
+
cleaned_content = self._clean_content(doc.content)
|
|
603
|
+
record_id = md5(cleaned_content.encode()).hexdigest()
|
|
604
|
+
|
|
605
|
+
meta_data = doc.meta_data or {}
|
|
606
|
+
if filters:
|
|
607
|
+
meta_data.update(filters)
|
|
608
|
+
|
|
609
|
+
record = {
|
|
610
|
+
"id": record_id, # use record_id as a reproducible id to avoid duplicates while upsert
|
|
611
|
+
"name": doc.name,
|
|
612
|
+
"meta_data": doc.meta_data,
|
|
613
|
+
"filters": filters,
|
|
614
|
+
"content": cleaned_content,
|
|
615
|
+
"embedding": doc.embedding,
|
|
616
|
+
"usage": doc.usage,
|
|
617
|
+
"content_hash": content_hash,
|
|
618
|
+
"content_id": doc.content_id,
|
|
619
|
+
}
|
|
620
|
+
batch_records_dict[record_id] = record # This deduplicates by ID
|
|
621
|
+
except Exception as e:
|
|
622
|
+
logger.error(f"Error processing document '{doc.name}': {e}")
|
|
623
|
+
|
|
624
|
+
# Convert dict to list for upsert
|
|
625
|
+
batch_records = list(batch_records_dict.values())
|
|
626
|
+
if not batch_records:
|
|
627
|
+
log_info("No valid records to upsert in this batch.")
|
|
628
|
+
continue
|
|
629
|
+
|
|
630
|
+
# Upsert the batch of records
|
|
631
|
+
insert_stmt = postgresql.insert(self.table).values(batch_records)
|
|
632
|
+
upsert_stmt = insert_stmt.on_conflict_do_update(
|
|
633
|
+
index_elements=["id"],
|
|
634
|
+
set_={
|
|
635
|
+
"name": insert_stmt.excluded.name,
|
|
636
|
+
"meta_data": insert_stmt.excluded.meta_data,
|
|
637
|
+
"filters": insert_stmt.excluded.filters,
|
|
638
|
+
"content": insert_stmt.excluded.content,
|
|
639
|
+
"embedding": insert_stmt.excluded.embedding,
|
|
640
|
+
"usage": insert_stmt.excluded.usage,
|
|
641
|
+
"content_hash": insert_stmt.excluded.content_hash,
|
|
642
|
+
"content_id": insert_stmt.excluded.content_id,
|
|
643
|
+
},
|
|
644
|
+
)
|
|
645
|
+
sess.execute(upsert_stmt)
|
|
646
|
+
sess.commit() # Commit batch independently
|
|
647
|
+
log_info(f"Upserted batch of {len(batch_records)} documents.")
|
|
648
|
+
except Exception as e:
|
|
649
|
+
logger.error(f"Error with batch starting at index {i}: {e}")
|
|
650
|
+
sess.rollback() # Rollback the current batch if there's an error
|
|
651
|
+
raise
|
|
652
|
+
except Exception as e:
|
|
653
|
+
logger.error(f"Error upserting documents: {e}")
|
|
654
|
+
raise
|
|
655
|
+
|
|
656
|
+
def update_metadata(self, content_id: str, metadata: Dict[str, Any]) -> None:
|
|
657
|
+
"""
|
|
658
|
+
Update the metadata for a document.
|
|
659
|
+
|
|
660
|
+
Args:
|
|
661
|
+
id (str): The ID of the document.
|
|
662
|
+
metadata (Dict[str, Any]): The metadata to update.
|
|
663
|
+
"""
|
|
664
|
+
try:
|
|
665
|
+
with self.Session() as sess:
|
|
666
|
+
# Merge JSONB instead of overwriting: coalesce(existing, '{}') || :new
|
|
667
|
+
stmt = (
|
|
668
|
+
update(self.table)
|
|
669
|
+
.where(self.table.c.content_id == content_id)
|
|
670
|
+
.values(
|
|
671
|
+
meta_data=func.coalesce(self.table.c.meta_data, text("'{}'::jsonb")).op("||")(
|
|
672
|
+
bindparam("md", metadata, type_=postgresql.JSONB)
|
|
673
|
+
),
|
|
674
|
+
filters=func.coalesce(self.table.c.filters, text("'{}'::jsonb")).op("||")(
|
|
675
|
+
bindparam("ft", metadata, type_=postgresql.JSONB)
|
|
676
|
+
),
|
|
677
|
+
)
|
|
678
|
+
)
|
|
679
|
+
sess.execute(stmt)
|
|
680
|
+
sess.commit()
|
|
681
|
+
except Exception as e:
|
|
682
|
+
logger.error(f"Error updating metadata for document {content_id}: {e}")
|
|
683
|
+
raise
|
|
684
|
+
|
|
685
|
+
def search(
|
|
686
|
+
self, query: str, limit: int = 5, filters: Optional[Union[Dict[str, Any], List[FilterExpr]]] = None
|
|
687
|
+
) -> List[Document]:
|
|
688
|
+
"""
|
|
689
|
+
Perform a search based on the configured search type.
|
|
690
|
+
|
|
691
|
+
Args:
|
|
692
|
+
query (str): The search query.
|
|
693
|
+
limit (int): Maximum number of results to return.
|
|
694
|
+
filters (Optional[Union[Dict[str, Any], List[FilterExpr]]]): Filters to apply to the search.
|
|
695
|
+
|
|
696
|
+
Returns:
|
|
697
|
+
List[Document]: List of matching documents.
|
|
698
|
+
"""
|
|
699
|
+
if self.search_type == SearchType.vector:
|
|
700
|
+
return self.vector_search(query=query, limit=limit, filters=filters)
|
|
701
|
+
elif self.search_type == SearchType.keyword:
|
|
702
|
+
return self.keyword_search(query=query, limit=limit, filters=filters)
|
|
703
|
+
elif self.search_type == SearchType.hybrid:
|
|
704
|
+
return self.hybrid_search(query=query, limit=limit, filters=filters)
|
|
705
|
+
else:
|
|
706
|
+
logger.error(f"Invalid search type '{self.search_type}'.")
|
|
707
|
+
return []
|
|
708
|
+
|
|
709
|
+
async def async_search(
|
|
710
|
+
self, query: str, limit: int = 5, filters: Optional[Union[Dict[str, Any], List[FilterExpr]]] = None
|
|
711
|
+
) -> List[Document]:
|
|
712
|
+
"""Search asynchronously by running in a thread."""
|
|
713
|
+
return await asyncio.to_thread(self.search, query, limit, filters)
|
|
714
|
+
|
|
715
|
+
def _dsl_to_sqlalchemy(self, filter_expr, table) -> ColumnElement[bool]:
|
|
716
|
+
op = filter_expr["op"]
|
|
717
|
+
|
|
718
|
+
if op == "EQ":
|
|
719
|
+
return table.c.meta_data[filter_expr["key"]].astext == str(filter_expr["value"])
|
|
720
|
+
elif op == "IN":
|
|
721
|
+
# Postgres JSONB array containment
|
|
722
|
+
return table.c.meta_data[filter_expr["key"]].astext.in_([str(v) for v in filter_expr["values"]])
|
|
723
|
+
elif op == "GT":
|
|
724
|
+
return table.c.meta_data[filter_expr["key"]].astext.cast(Integer) > filter_expr["value"]
|
|
725
|
+
elif op == "LT":
|
|
726
|
+
return table.c.meta_data[filter_expr["key"]].astext.cast(Integer) < filter_expr["value"]
|
|
727
|
+
elif op == "NOT":
|
|
728
|
+
return not_(self._dsl_to_sqlalchemy(filter_expr["condition"], table))
|
|
729
|
+
elif op == "AND":
|
|
730
|
+
return and_(*[self._dsl_to_sqlalchemy(cond, table) for cond in filter_expr["conditions"]])
|
|
731
|
+
elif op == "OR":
|
|
732
|
+
return or_(*[self._dsl_to_sqlalchemy(cond, table) for cond in filter_expr["conditions"]])
|
|
733
|
+
else:
|
|
734
|
+
raise ValueError(f"Unknown filter operator: {op}")
|
|
735
|
+
|
|
736
|
+
    def vector_search(
        self, query: str, limit: int = 5, filters: Optional[Union[Dict[str, Any], List[FilterExpr]]] = None
    ) -> List[Document]:
        """
        Perform a vector similarity search.

        Args:
            query (str): The search query.
            limit (int): Maximum number of results to return.
            filters (Optional[Union[Dict[str, Any], List[FilterExpr]]]): Filters to apply to the search.

        Returns:
            List[Document]: List of matching documents.
        """
        try:
            # Get the embedding for the query string
            query_embedding = self.embedder.get_embedding(query)
            if query_embedding is None:
                logger.error(f"Error getting embedding for Query: {query}")
                return []

            # Define the columns to select
            columns = [
                self.table.c.id,
                self.table.c.name,
                self.table.c.meta_data,
                self.table.c.content,
                self.table.c.embedding,
                self.table.c.usage,
            ]

            # Build the base statement
            stmt = select(*columns)

            # Apply filters if provided
            if filters is not None:
                # Handle dict filters
                if isinstance(filters, dict):
                    stmt = stmt.where(self.table.c.meta_data.contains(filters))
                # Handle FilterExpr DSL
                else:
                    # Convert each DSL expression to SQLAlchemy and AND them together
                    sqlalchemy_conditions = [
                        self._dsl_to_sqlalchemy(f.to_dict() if hasattr(f, "to_dict") else f, self.table)
                        for f in filters
                    ]
                    stmt = stmt.where(and_(*sqlalchemy_conditions))

            # Order the results based on the distance metric
            if self.distance == Distance.l2:
                stmt = stmt.order_by(self.table.c.embedding.l2_distance(query_embedding))
            elif self.distance == Distance.cosine:
                stmt = stmt.order_by(self.table.c.embedding.cosine_distance(query_embedding))
            elif self.distance == Distance.max_inner_product:
                stmt = stmt.order_by(self.table.c.embedding.max_inner_product(query_embedding))
            else:
                logger.error(f"Unknown distance metric: {self.distance}")
                return []

            # Limit the number of results
            stmt = stmt.limit(limit)

            # Log the query for debugging
            log_debug(f"Vector search query: {stmt}")

            # Execute the query
            try:
                with self.Session() as sess, sess.begin():
                    if self.vector_index is not None:
                        if isinstance(self.vector_index, Ivfflat):
                            sess.execute(text(f"SET LOCAL ivfflat.probes = {self.vector_index.probes}"))
                        elif isinstance(self.vector_index, HNSW):
                            sess.execute(text(f"SET LOCAL hnsw.ef_search = {self.vector_index.ef_search}"))
                    results = sess.execute(stmt).fetchall()
            except Exception as e:
                logger.error(f"Error performing semantic search: {e}")
                logger.error("Table might not exist, creating for future use")
                self.create()
                return []

            # Process the results and convert to Document objects
            search_results: List[Document] = []
            for result in results:
                search_results.append(
                    Document(
                        id=result.id,
                        name=result.name,
                        meta_data=result.meta_data,
                        content=result.content,
                        embedder=self.embedder,
                        embedding=result.embedding,
                        usage=result.usage,
                    )
                )

            if self.reranker:
                search_results = self.reranker.rerank(query=query, documents=search_results)

            log_info(f"Found {len(search_results)} documents")
            return search_results
        except Exception as e:
            logger.error(f"Error during vector search: {e}")
            return []

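A minimal usage sketch. The import path and constructor arguments below are assumptions for illustration; only `vector_search` and its parameters come from the code above.

```python
from agno.vectordb.pgvector import PgVector  # import path assumed

# Connection URL and table name are illustrative.
vector_db = PgVector(table_name="recipes", db_url="postgresql+psycopg://ai:ai@localhost:5432/ai")

# Dict filters are matched against the meta_data JSONB column.
docs = vector_db.vector_search("how do I make pad thai?", limit=3, filters={"cuisine": "thai"})
for doc in docs:
    print(doc.name, doc.content[:80])
```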
    def enable_prefix_matching(self, query: str) -> str:
        """
        Preprocess the query for prefix matching.

        Args:
            query (str): The original query.

        Returns:
            str: The processed query with prefix matching enabled.
        """
        # Append '*' to each word for prefix matching
        words = query.strip().split()
        processed_words = [word + "*" for word in words]
        return " ".join(processed_words)

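The transformation is purely lexical. Continuing the sketch above, with `prefix_match` enabled each token gets a trailing `*` before the query reaches `websearch_to_tsquery`:

```python
assert vector_db.enable_prefix_matching("thai curry") == "thai* curry*"
```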
    def keyword_search(
        self, query: str, limit: int = 5, filters: Optional[Union[Dict[str, Any], List[FilterExpr]]] = None
    ) -> List[Document]:
        """
        Perform a keyword search on the 'content' column.

        Args:
            query (str): The search query.
            limit (int): Maximum number of results to return.
            filters (Optional[Union[Dict[str, Any], List[FilterExpr]]]): Filters to apply to the search.

        Returns:
            List[Document]: List of matching documents.
        """
        try:
            # Define the columns to select
            columns = [
                self.table.c.id,
                self.table.c.name,
                self.table.c.meta_data,
                self.table.c.content,
                self.table.c.embedding,
                self.table.c.usage,
            ]

            # Build the base statement
            stmt = select(*columns)

            # Build the text search vector
            ts_vector = func.to_tsvector(self.content_language, self.table.c.content)
            # Create the ts_query using websearch_to_tsquery with parameter binding
            processed_query = self.enable_prefix_matching(query) if self.prefix_match else query
            ts_query = func.websearch_to_tsquery(self.content_language, bindparam("query", value=processed_query))
            # Compute the text rank
            text_rank = func.ts_rank_cd(ts_vector, ts_query)

            # Apply filters if provided
            if filters is not None:
                # Handle dict filters
                if isinstance(filters, dict):
                    stmt = stmt.where(self.table.c.meta_data.contains(filters))
                # Handle FilterExpr DSL
                else:
                    # Convert each DSL expression to SQLAlchemy and AND them together
                    sqlalchemy_conditions = [
                        self._dsl_to_sqlalchemy(f.to_dict() if hasattr(f, "to_dict") else f, self.table)
                        for f in filters
                    ]
                    stmt = stmt.where(and_(*sqlalchemy_conditions))

            # Order by the relevance rank
            stmt = stmt.order_by(text_rank.desc())

            # Limit the number of results
            stmt = stmt.limit(limit)

            # Log the query for debugging
            log_debug(f"Keyword search query: {stmt}")

            # Execute the query
            try:
                with self.Session() as sess, sess.begin():
                    results = sess.execute(stmt).fetchall()
            except Exception as e:
                logger.error(f"Error performing keyword search: {e}")
                logger.error("Table might not exist, creating for future use")
                self.create()
                return []

            # Process the results and convert to Document objects
            search_results: List[Document] = []
            for result in results:
                search_results.append(
                    Document(
                        id=result.id,
                        name=result.name,
                        meta_data=result.meta_data,
                        content=result.content,
                        embedder=self.embedder,
                        embedding=result.embedding,
                        usage=result.usage,
                    )
                )

            log_info(f"Found {len(search_results)} documents")
            return search_results
        except Exception as e:
            logger.error(f"Error during keyword search: {e}")
            return []

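Roughly, the statement built above corresponds to the SQL below (schema, table, and `content_language='english'` are illustrative). Note there is no `@@` match predicate: `ts_rank_cd` returns 0 for non-matching rows, so they simply rank last rather than being excluded.

```python
# Rough SQL equivalent of the keyword_search statement (names assumed).
keyword_sql = """
SELECT id, name, meta_data, content, embedding, usage
FROM ai.recipes
ORDER BY ts_rank_cd(
    to_tsvector('english', content),
    websearch_to_tsquery('english', %(query)s)
) DESC
LIMIT 5;
"""
```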
    def hybrid_search(
        self,
        query: str,
        limit: int = 5,
        filters: Optional[Union[Dict[str, Any], List[FilterExpr]]] = None,
    ) -> List[Document]:
        """
        Perform a hybrid search combining vector similarity and full-text search.

        Args:
            query (str): The search query.
            limit (int): Maximum number of results to return.
            filters (Optional[Union[Dict[str, Any], List[FilterExpr]]]): Filters to apply to the search.

        Returns:
            List[Document]: List of matching documents.
        """
        try:
            # Get the embedding for the query string
            query_embedding = self.embedder.get_embedding(query)
            if query_embedding is None:
                logger.error(f"Error getting embedding for Query: {query}")
                return []

            # Define the columns to select
            columns = [
                self.table.c.id,
                self.table.c.name,
                self.table.c.meta_data,
                self.table.c.content,
                self.table.c.embedding,
                self.table.c.usage,
            ]

            # Build the text search vector
            ts_vector = func.to_tsvector(self.content_language, self.table.c.content)
            # Create the ts_query using websearch_to_tsquery with parameter binding
            processed_query = self.enable_prefix_matching(query) if self.prefix_match else query
            ts_query = func.websearch_to_tsquery(self.content_language, bindparam("query", value=processed_query))
            # Compute the text rank
            text_rank = func.ts_rank_cd(ts_vector, ts_query)

            # Compute the vector similarity score
            if self.distance == Distance.l2:
                # For L2 distance, smaller distances are better
                vector_distance = self.table.c.embedding.l2_distance(query_embedding)
                # Invert and normalize the distance to get a similarity score between 0 and 1
                vector_score = 1 / (1 + vector_distance)
            elif self.distance == Distance.cosine:
                # For cosine distance, smaller distances are better
                vector_distance = self.table.c.embedding.cosine_distance(query_embedding)
                vector_score = 1 / (1 + vector_distance)
            elif self.distance == Distance.max_inner_product:
                # For inner product, higher values are better
                # Assume embeddings are normalized, so inner product ranges from -1 to 1
                raw_vector_score = self.table.c.embedding.max_inner_product(query_embedding)
                # Normalize to range [0, 1]
                vector_score = (raw_vector_score + 1) / 2
            else:
                logger.error(f"Unknown distance metric: {self.distance}")
                return []

            # Apply weights to control the influence of each score
            # Validate the vector_score_weight parameter
            if not 0 <= self.vector_score_weight <= 1:
                raise ValueError("vector_score_weight must be between 0 and 1")
            text_rank_weight = 1 - self.vector_score_weight  # weight for text rank

            # Combine the scores into a hybrid score
            hybrid_score = (self.vector_score_weight * vector_score) + (text_rank_weight * text_rank)

            # Build the base statement, including the hybrid score
            stmt = select(*columns, hybrid_score.label("hybrid_score"))

            # Add the full-text search condition
            # stmt = stmt.where(ts_vector.op("@@")(ts_query))

            # Apply filters if provided
            if filters is not None:
                # Handle dict filters
                if isinstance(filters, dict):
                    stmt = stmt.where(self.table.c.meta_data.contains(filters))
                # Handle FilterExpr DSL
                else:
                    # Convert each DSL expression to SQLAlchemy and AND them together
                    sqlalchemy_conditions = [
                        self._dsl_to_sqlalchemy(f.to_dict() if hasattr(f, "to_dict") else f, self.table)
                        for f in filters
                    ]
                    stmt = stmt.where(and_(*sqlalchemy_conditions))

            # Order the results by the hybrid score in descending order
            stmt = stmt.order_by(desc("hybrid_score"))

            # Limit the number of results
            stmt = stmt.limit(limit)

            # Log the query for debugging
            log_debug(f"Hybrid search query: {stmt}")

            # Execute the query
            try:
                with self.Session() as sess, sess.begin():
                    if self.vector_index is not None:
                        if isinstance(self.vector_index, Ivfflat):
                            sess.execute(text(f"SET LOCAL ivfflat.probes = {self.vector_index.probes}"))
                        elif isinstance(self.vector_index, HNSW):
                            sess.execute(text(f"SET LOCAL hnsw.ef_search = {self.vector_index.ef_search}"))
                    results = sess.execute(stmt).fetchall()
            except Exception as e:
                logger.error(f"Error performing hybrid search: {e}")
                return []

            # Process the results and convert to Document objects
            search_results: List[Document] = []
            for result in results:
                search_results.append(
                    Document(
                        id=result.id,
                        name=result.name,
                        meta_data=result.meta_data,
                        content=result.content,
                        embedder=self.embedder,
                        embedding=result.embedding,
                        usage=result.usage,
                    )
                )

            if self.reranker:
                search_results = self.reranker.rerank(query=query, documents=search_results)

            log_info(f"Found {len(search_results)} documents")
            return search_results
        except Exception as e:
            logger.error(f"Error during hybrid search: {e}")
            return []

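A worked example of the blended score, using illustrative numbers with `vector_score_weight = 0.5` and cosine distance:

```python
vector_score = 1 / (1 + 0.25)             # cosine distance 0.25 -> similarity 0.8
hybrid = 0.5 * vector_score + 0.5 * 0.12  # text rank 0.12 -> hybrid score 0.46
```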
    def drop(self) -> None:
        """
        Drop the table from the database.
        """
        if self.table_exists():
            try:
                log_debug(f"Dropping table '{self.table.fullname}'.")
                self.table.drop(self.db_engine)
                log_info(f"Table '{self.table.fullname}' dropped successfully.")
            except Exception as e:
                logger.error(f"Error dropping table '{self.table.fullname}': {e}")
                raise
        else:
            log_info(f"Table '{self.table.fullname}' does not exist.")

    async def async_drop(self) -> None:
        """Drop the table asynchronously by running in a thread."""
        await asyncio.to_thread(self.drop)

    def exists(self) -> bool:
        """
        Check if the table exists in the database.

        Returns:
            bool: True if the table exists, False otherwise.
        """
        return self.table_exists()

    async def async_exists(self) -> bool:
        """Check if table exists asynchronously by running in a thread."""
        return await asyncio.to_thread(self.exists)

    def get_count(self) -> int:
        """
        Get the number of records in the table.

        Returns:
            int: The number of records in the table.
        """
        try:
            with self.Session() as sess, sess.begin():
                stmt = select(func.count(self.table.c.name)).select_from(self.table)
                result = sess.execute(stmt).scalar()
                return int(result) if result is not None else 0
        except Exception as e:
            logger.error(f"Error getting count from table '{self.table.fullname}': {e}")
            return 0

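The `async_*` wrappers simply offload the blocking call to a worker thread, so they can be awaited from an event loop. Continuing the earlier sketch:

```python
import asyncio

async def main() -> None:
    # async_exists wraps exists() via asyncio.to_thread; the same pattern
    # works for any of the blocking methods.
    if await vector_db.async_exists():
        print(await asyncio.to_thread(vector_db.get_count))

asyncio.run(main())
```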
    def optimize(self, force_recreate: bool = False) -> None:
        """
        Optimize the vector database by creating or recreating necessary indexes.

        Args:
            force_recreate (bool): If True, existing indexes will be dropped and recreated.
        """
        log_debug("==== Optimizing Vector DB ====")
        self._create_vector_index(force_recreate=force_recreate)
        self._create_gin_index(force_recreate=force_recreate)
        log_debug("==== Optimized Vector DB ====")

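Typical call, e.g. after a bulk load (continuing the earlier sketch):

```python
vector_db.optimize(force_recreate=True)  # drop and rebuild the vector and GIN indexes
```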
    def _index_exists(self, index_name: str) -> bool:
        """
        Check if an index with the given name exists.

        Args:
            index_name (str): The name of the index to check.

        Returns:
            bool: True if the index exists, False otherwise.
        """
        inspector = inspect(self.db_engine)
        indexes = inspector.get_indexes(self.table.name, schema=self.schema)
        return any(idx["name"] == index_name for idx in indexes)

    def _drop_index(self, index_name: str) -> None:
        """
        Drop the index with the given name.

        Args:
            index_name (str): The name of the index to drop.
        """
        try:
            with self.Session() as sess, sess.begin():
                drop_index_sql = f'DROP INDEX IF EXISTS "{self.schema}"."{index_name}";'
                sess.execute(text(drop_index_sql))
        except Exception as e:
            logger.error(f"Error dropping index '{index_name}': {e}")
            raise

    def _create_vector_index(self, force_recreate: bool = False) -> None:
        """
        Create or recreate the vector index.

        Args:
            force_recreate (bool): If True, existing index will be dropped and recreated.
        """
        if self.vector_index is None:
            log_debug("No vector index specified, skipping vector index optimization.")
            return

        # Generate index name if not provided
        if self.vector_index.name is None:
            index_type = "ivfflat" if isinstance(self.vector_index, Ivfflat) else "hnsw"
            self.vector_index.name = f"{self.table_name}_{index_type}_index"

        # Determine index distance operator
        index_distance = {
            Distance.l2: "vector_l2_ops",
            Distance.max_inner_product: "vector_ip_ops",
            Distance.cosine: "vector_cosine_ops",
        }.get(self.distance, "vector_cosine_ops")

        # Get the fully qualified table name
        table_fullname = self.table.fullname  # includes schema if any

        # Check if vector index already exists
        vector_index_exists = self._index_exists(self.vector_index.name)

        if vector_index_exists:
            log_info(f"Vector index '{self.vector_index.name}' already exists.")
            if force_recreate:
                log_info(f"Force recreating vector index '{self.vector_index.name}'. Dropping existing index.")
                self._drop_index(self.vector_index.name)
            else:
                log_info(f"Skipping vector index creation as index '{self.vector_index.name}' already exists.")
                return

        # Proceed to create the vector index
        try:
            with self.Session() as sess, sess.begin():
                # Set configuration parameters
                if self.vector_index.configuration:
                    log_debug(f"Setting configuration: {self.vector_index.configuration}")
                    for key, value in self.vector_index.configuration.items():
                        sess.execute(text(f"SET {key} = :value;"), {"value": value})

                if isinstance(self.vector_index, Ivfflat):
                    self._create_ivfflat_index(sess, table_fullname, index_distance)
                elif isinstance(self.vector_index, HNSW):
                    self._create_hnsw_index(sess, table_fullname, index_distance)
                else:
                    logger.error(f"Unknown index type: {type(self.vector_index)}")
                    return
        except Exception as e:
            logger.error(f"Error creating vector index '{self.vector_index.name}': {e}")
            raise

    def _create_ivfflat_index(self, sess: Session, table_fullname: str, index_distance: str) -> None:
        """
        Create an IVFFlat index.

        Args:
            sess (Session): SQLAlchemy session.
            table_fullname (str): Fully qualified table name.
            index_distance (str): Distance metric for the index.
        """
        # Cast index to Ivfflat for type hinting
        self.vector_index = cast(Ivfflat, self.vector_index)

        # Determine number of lists
        num_lists = self.vector_index.lists
        if self.vector_index.dynamic_lists:
            total_records = self.get_count()
            log_debug(f"Number of records: {total_records}")
            if total_records < 1000000:
                num_lists = max(int(total_records / 1000), 1)  # Ensure at least one list
            else:
                num_lists = max(int(sqrt(total_records)), 1)

        # Set ivfflat.probes
        sess.execute(text("SET ivfflat.probes = :probes;"), {"probes": self.vector_index.probes})

        log_debug(
            f"Creating Ivfflat index '{self.vector_index.name}' on table '{table_fullname}' with "
            f"lists: {num_lists}, probes: {self.vector_index.probes}, "
            f"and distance metric: {index_distance}"
        )

        # Create index
        create_index_sql = text(
            f'CREATE INDEX "{self.vector_index.name}" ON {table_fullname} '
            f"USING ivfflat (embedding {index_distance}) "
            f"WITH (lists = :num_lists);"
        )
        sess.execute(create_index_sql, {"num_lists": num_lists})

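A worked example of the `dynamic_lists` heuristic above: roughly one list per thousand rows below a million records, switching to a square-root rule above that.

```python
from math import sqrt

for total_records in (50_000, 500_000, 4_000_000):
    if total_records < 1_000_000:
        num_lists = max(int(total_records / 1000), 1)  # 50, 500
    else:
        num_lists = max(int(sqrt(total_records)), 1)   # 2000
    print(total_records, num_lists)
```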
    def _create_hnsw_index(self, sess: Session, table_fullname: str, index_distance: str) -> None:
        """
        Create an HNSW index.

        Args:
            sess (Session): SQLAlchemy session.
            table_fullname (str): Fully qualified table name.
            index_distance (str): Distance metric for the index.
        """
        # Cast index to HNSW for type hinting
        self.vector_index = cast(HNSW, self.vector_index)

        log_debug(
            f"Creating HNSW index '{self.vector_index.name}' on table '{table_fullname}' with "
            f"m: {self.vector_index.m}, ef_construction: {self.vector_index.ef_construction}, "
            f"and distance metric: {index_distance}"
        )

        # Create index
        create_index_sql = text(
            f'CREATE INDEX "{self.vector_index.name}" ON {table_fullname} '
            f"USING hnsw (embedding {index_distance}) "
            f"WITH (m = :m, ef_construction = :ef_construction);"
        )
        sess.execute(create_index_sql, {"m": self.vector_index.m, "ef_construction": self.vector_index.ef_construction})

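For reference, the builders above render DDL of this shape. The identifier and parameter values are illustrative; `m = 16` and `ef_construction = 64` are pgvector's documented defaults, not values taken from this file.

```python
example_hnsw_sql = (
    'CREATE INDEX "recipes_hnsw_index" ON ai.recipes '
    "USING hnsw (embedding vector_cosine_ops) "
    "WITH (m = 16, ef_construction = 64);"
)
```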
    def _create_gin_index(self, force_recreate: bool = False) -> None:
        """
        Create or recreate the GIN index for full-text search.

        Args:
            force_recreate (bool): If True, existing index will be dropped and recreated.
        """
        gin_index_name = f"{self.table_name}_content_gin_index"

        gin_index_exists = self._index_exists(gin_index_name)

        if gin_index_exists:
            log_info(f"GIN index '{gin_index_name}' already exists.")
            if force_recreate:
                log_info(f"Force recreating GIN index '{gin_index_name}'. Dropping existing index.")
                self._drop_index(gin_index_name)
            else:
                log_info(f"Skipping GIN index creation as index '{gin_index_name}' already exists.")
                return

        # Proceed to create GIN index
        try:
            with self.Session() as sess, sess.begin():
                log_debug(f"Creating GIN index '{gin_index_name}' on table '{self.table.fullname}'.")
                # Create index; quote the language name so Postgres parses it as a
                # regconfig literal rather than a column reference
                create_gin_index_sql = text(
                    f'CREATE INDEX "{gin_index_name}" ON {self.table.fullname} '
                    f"USING GIN (to_tsvector('{self.content_language}', content));"
                )
                sess.execute(create_gin_index_sql)
        except Exception as e:
            logger.error(f"Error creating GIN index '{gin_index_name}': {e}")
            raise

    def delete(self) -> bool:
        """
        Delete all records from the table.

        Returns:
            bool: True if deletion was successful, False otherwise.
        """
        from sqlalchemy import delete

        try:
            with self.Session() as sess:
                sess.execute(delete(self.table))
                sess.commit()
            log_info(f"Deleted all records from table '{self.table.fullname}'.")
            return True
        except Exception as e:
            # The session context manager discards the failed transaction on exit,
            # so no explicit rollback is needed (sess is out of scope here anyway)
            logger.error(f"Error deleting rows from table '{self.table.fullname}': {e}")
            return False

    def delete_by_id(self, id: str) -> bool:
        """
        Delete content by ID.
        """
        try:
            with self.Session() as sess, sess.begin():
                stmt = self.table.delete().where(self.table.c.id == id)
                sess.execute(stmt)
                # sess.begin() commits on exit and rolls back on error, so no
                # explicit commit()/rollback() is needed here or in the helpers below
            log_info(f"Deleted records with id '{id}' from table '{self.table.fullname}'.")
            return True
        except Exception as e:
            logger.error(f"Error deleting rows from table '{self.table.fullname}': {e}")
            return False

    def delete_by_name(self, name: str) -> bool:
        """
        Delete content by name.
        """
        try:
            with self.Session() as sess, sess.begin():
                stmt = self.table.delete().where(self.table.c.name == name)
                sess.execute(stmt)
            log_info(f"Deleted records with name '{name}' from table '{self.table.fullname}'.")
            return True
        except Exception as e:
            logger.error(f"Error deleting rows from table '{self.table.fullname}': {e}")
            return False

    def delete_by_metadata(self, metadata: Dict[str, Any]) -> bool:
        """
        Delete content by metadata.
        """
        try:
            with self.Session() as sess, sess.begin():
                stmt = self.table.delete().where(self.table.c.meta_data.contains(metadata))
                sess.execute(stmt)
            log_info(f"Deleted records with metadata '{metadata}' from table '{self.table.fullname}'.")
            return True
        except Exception as e:
            logger.error(f"Error deleting rows from table '{self.table.fullname}': {e}")
            return False

    def delete_by_content_id(self, content_id: str) -> bool:
        """
        Delete content by content ID.
        """
        try:
            with self.Session() as sess, sess.begin():
                stmt = self.table.delete().where(self.table.c.content_id == content_id)
                sess.execute(stmt)
            log_info(f"Deleted records with content ID '{content_id}' from table '{self.table.fullname}'.")
            return True
        except Exception as e:
            logger.error(f"Error deleting rows from table '{self.table.fullname}': {e}")
            return False

    def _delete_by_content_hash(self, content_hash: str) -> bool:
        """
        Delete content by content hash.
        """
        try:
            with self.Session() as sess, sess.begin():
                stmt = self.table.delete().where(self.table.c.content_hash == content_hash)
                sess.execute(stmt)
            log_info(f"Deleted records with content hash '{content_hash}' from table '{self.table.fullname}'.")
            return True
        except Exception as e:
            logger.error(f"Error deleting rows from table '{self.table.fullname}': {e}")
            return False

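A sketch of the targeted delete helpers above (argument values are illustrative; each returns `True` on success):

```python
vector_db.delete_by_id("doc-123")
vector_db.delete_by_name("pad_thai.pdf")
vector_db.delete_by_metadata({"cuisine": "thai"})
vector_db.delete_by_content_id("content-abc")
```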
    def __deepcopy__(self, memo):
        """
        Create a deep copy of the PgVector instance, handling unpickleable attributes.

        Args:
            memo (dict): A dictionary of objects already copied during the current copying pass.

        Returns:
            PgVector: A deep-copied instance of PgVector.
        """
        from copy import deepcopy

        # Create a new instance without calling __init__
        cls = self.__class__
        copied_obj = cls.__new__(cls)
        memo[id(self)] = copied_obj

        # Deep copy attributes
        for k, v in self.__dict__.items():
            if k in {"metadata", "table"}:
                continue
            # Reuse db_engine and Session without copying
            elif k in {"db_engine", "Session", "embedder"}:
                setattr(copied_obj, k, v)
            else:
                setattr(copied_obj, k, deepcopy(v, memo))

        # Recreate metadata and table for the copied instance
        copied_obj.metadata = MetaData(schema=copied_obj.schema)
        copied_obj.table = copied_obj.get_table()

        return copied_obj

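The practical consequence of the copy hook above: deep copies share the engine, session factory, and embedder, but rebuild the SQLAlchemy `MetaData` and `Table` objects.

```python
from copy import deepcopy

replica = deepcopy(vector_db)
assert replica.db_engine is vector_db.db_engine  # engine reused, not copied
assert replica.table is not vector_db.table      # table recreated for the copy
```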
    def get_supported_search_types(self) -> List[str]:
        return [SearchType.vector, SearchType.keyword, SearchType.hybrid]