ai-parrot 0.17.2 (ai_parrot-0.17.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentui/.prettierrc +15 -0
- agentui/QUICKSTART.md +272 -0
- agentui/README.md +59 -0
- agentui/env.example +16 -0
- agentui/jsconfig.json +14 -0
- agentui/package-lock.json +4242 -0
- agentui/package.json +34 -0
- agentui/scripts/postinstall/apply-patches.mjs +260 -0
- agentui/src/app.css +61 -0
- agentui/src/app.d.ts +13 -0
- agentui/src/app.html +12 -0
- agentui/src/components/LoadingSpinner.svelte +64 -0
- agentui/src/components/ThemeSwitcher.svelte +159 -0
- agentui/src/components/index.js +4 -0
- agentui/src/lib/api/bots.ts +60 -0
- agentui/src/lib/api/chat.ts +22 -0
- agentui/src/lib/api/http.ts +25 -0
- agentui/src/lib/components/BotCard.svelte +33 -0
- agentui/src/lib/components/ChatBubble.svelte +63 -0
- agentui/src/lib/components/Toast.svelte +21 -0
- agentui/src/lib/config.ts +20 -0
- agentui/src/lib/stores/auth.svelte.ts +73 -0
- agentui/src/lib/stores/theme.svelte.js +64 -0
- agentui/src/lib/stores/toast.svelte.ts +31 -0
- agentui/src/lib/utils/conversation.ts +39 -0
- agentui/src/routes/+layout.svelte +20 -0
- agentui/src/routes/+page.svelte +232 -0
- agentui/src/routes/login/+page.svelte +200 -0
- agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
- agentui/src/routes/talk/[agentId]/+page.ts +7 -0
- agentui/static/README.md +1 -0
- agentui/svelte.config.js +11 -0
- agentui/tailwind.config.ts +53 -0
- agentui/tsconfig.json +3 -0
- agentui/vite.config.ts +10 -0
- ai_parrot-0.17.2.dist-info/METADATA +472 -0
- ai_parrot-0.17.2.dist-info/RECORD +535 -0
- ai_parrot-0.17.2.dist-info/WHEEL +6 -0
- ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
- ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
- ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
- crew-builder/.prettierrc +15 -0
- crew-builder/QUICKSTART.md +259 -0
- crew-builder/README.md +113 -0
- crew-builder/env.example +17 -0
- crew-builder/jsconfig.json +14 -0
- crew-builder/package-lock.json +4182 -0
- crew-builder/package.json +37 -0
- crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
- crew-builder/src/app.css +62 -0
- crew-builder/src/app.d.ts +13 -0
- crew-builder/src/app.html +12 -0
- crew-builder/src/components/LoadingSpinner.svelte +64 -0
- crew-builder/src/components/ThemeSwitcher.svelte +149 -0
- crew-builder/src/components/index.js +9 -0
- crew-builder/src/lib/api/bots.ts +60 -0
- crew-builder/src/lib/api/chat.ts +80 -0
- crew-builder/src/lib/api/client.ts +56 -0
- crew-builder/src/lib/api/crew/crew.ts +136 -0
- crew-builder/src/lib/api/index.ts +5 -0
- crew-builder/src/lib/api/o365/auth.ts +65 -0
- crew-builder/src/lib/auth/auth.ts +54 -0
- crew-builder/src/lib/components/AgentNode.svelte +43 -0
- crew-builder/src/lib/components/BotCard.svelte +33 -0
- crew-builder/src/lib/components/ChatBubble.svelte +67 -0
- crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
- crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
- crew-builder/src/lib/components/JsonViewer.svelte +24 -0
- crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
- crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
- crew-builder/src/lib/components/Toast.svelte +67 -0
- crew-builder/src/lib/components/Toolbar.svelte +157 -0
- crew-builder/src/lib/components/index.ts +10 -0
- crew-builder/src/lib/config.ts +8 -0
- crew-builder/src/lib/stores/auth.svelte.ts +228 -0
- crew-builder/src/lib/stores/crewStore.ts +369 -0
- crew-builder/src/lib/stores/theme.svelte.js +145 -0
- crew-builder/src/lib/stores/toast.svelte.ts +69 -0
- crew-builder/src/lib/utils/conversation.ts +39 -0
- crew-builder/src/lib/utils/markdown.ts +122 -0
- crew-builder/src/lib/utils/talkHistory.ts +47 -0
- crew-builder/src/routes/+layout.svelte +20 -0
- crew-builder/src/routes/+page.svelte +539 -0
- crew-builder/src/routes/agents/+page.svelte +247 -0
- crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
- crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
- crew-builder/src/routes/builder/+page.svelte +204 -0
- crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
- crew-builder/src/routes/crew/ask/+page.ts +1 -0
- crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
- crew-builder/src/routes/login/+page.svelte +197 -0
- crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
- crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
- crew-builder/static/README.md +1 -0
- crew-builder/svelte.config.js +11 -0
- crew-builder/tailwind.config.ts +53 -0
- crew-builder/tsconfig.json +3 -0
- crew-builder/vite.config.ts +10 -0
- mcp_servers/calculator_server.py +309 -0
- parrot/__init__.py +27 -0
- parrot/__pycache__/__init__.cpython-310.pyc +0 -0
- parrot/__pycache__/version.cpython-310.pyc +0 -0
- parrot/_version.py +34 -0
- parrot/a2a/__init__.py +48 -0
- parrot/a2a/client.py +658 -0
- parrot/a2a/discovery.py +89 -0
- parrot/a2a/mixin.py +257 -0
- parrot/a2a/models.py +376 -0
- parrot/a2a/server.py +770 -0
- parrot/agents/__init__.py +29 -0
- parrot/bots/__init__.py +12 -0
- parrot/bots/a2a_agent.py +19 -0
- parrot/bots/abstract.py +3139 -0
- parrot/bots/agent.py +1129 -0
- parrot/bots/basic.py +9 -0
- parrot/bots/chatbot.py +669 -0
- parrot/bots/data.py +1618 -0
- parrot/bots/database/__init__.py +5 -0
- parrot/bots/database/abstract.py +3071 -0
- parrot/bots/database/cache.py +286 -0
- parrot/bots/database/models.py +468 -0
- parrot/bots/database/prompts.py +154 -0
- parrot/bots/database/retries.py +98 -0
- parrot/bots/database/router.py +269 -0
- parrot/bots/database/sql.py +41 -0
- parrot/bots/db/__init__.py +6 -0
- parrot/bots/db/abstract.py +556 -0
- parrot/bots/db/bigquery.py +602 -0
- parrot/bots/db/cache.py +85 -0
- parrot/bots/db/documentdb.py +668 -0
- parrot/bots/db/elastic.py +1014 -0
- parrot/bots/db/influx.py +898 -0
- parrot/bots/db/mock.py +96 -0
- parrot/bots/db/multi.py +783 -0
- parrot/bots/db/prompts.py +185 -0
- parrot/bots/db/sql.py +1255 -0
- parrot/bots/db/tools.py +212 -0
- parrot/bots/document.py +680 -0
- parrot/bots/hrbot.py +15 -0
- parrot/bots/kb.py +170 -0
- parrot/bots/mcp.py +36 -0
- parrot/bots/orchestration/README.md +463 -0
- parrot/bots/orchestration/__init__.py +1 -0
- parrot/bots/orchestration/agent.py +155 -0
- parrot/bots/orchestration/crew.py +3330 -0
- parrot/bots/orchestration/fsm.py +1179 -0
- parrot/bots/orchestration/hr.py +434 -0
- parrot/bots/orchestration/storage/__init__.py +4 -0
- parrot/bots/orchestration/storage/memory.py +100 -0
- parrot/bots/orchestration/storage/mixin.py +119 -0
- parrot/bots/orchestration/verify.py +202 -0
- parrot/bots/product.py +204 -0
- parrot/bots/prompts/__init__.py +96 -0
- parrot/bots/prompts/agents.py +155 -0
- parrot/bots/prompts/data.py +216 -0
- parrot/bots/prompts/output_generation.py +8 -0
- parrot/bots/scraper/__init__.py +3 -0
- parrot/bots/scraper/models.py +122 -0
- parrot/bots/scraper/scraper.py +1173 -0
- parrot/bots/scraper/templates.py +115 -0
- parrot/bots/stores/__init__.py +5 -0
- parrot/bots/stores/local.py +172 -0
- parrot/bots/webdev.py +81 -0
- parrot/cli.py +17 -0
- parrot/clients/__init__.py +16 -0
- parrot/clients/base.py +1491 -0
- parrot/clients/claude.py +1191 -0
- parrot/clients/factory.py +129 -0
- parrot/clients/google.py +4567 -0
- parrot/clients/gpt.py +1975 -0
- parrot/clients/grok.py +432 -0
- parrot/clients/groq.py +986 -0
- parrot/clients/hf.py +582 -0
- parrot/clients/models.py +18 -0
- parrot/conf.py +395 -0
- parrot/embeddings/__init__.py +9 -0
- parrot/embeddings/base.py +157 -0
- parrot/embeddings/google.py +98 -0
- parrot/embeddings/huggingface.py +74 -0
- parrot/embeddings/openai.py +84 -0
- parrot/embeddings/processor.py +88 -0
- parrot/exceptions.c +13868 -0
- parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/exceptions.pxd +22 -0
- parrot/exceptions.pxi +15 -0
- parrot/exceptions.pyx +44 -0
- parrot/generators/__init__.py +29 -0
- parrot/generators/base.py +200 -0
- parrot/generators/html.py +293 -0
- parrot/generators/react.py +205 -0
- parrot/generators/streamlit.py +203 -0
- parrot/generators/template.py +105 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/agent.py +861 -0
- parrot/handlers/agents/__init__.py +1 -0
- parrot/handlers/agents/abstract.py +900 -0
- parrot/handlers/bots.py +338 -0
- parrot/handlers/chat.py +915 -0
- parrot/handlers/creation.sql +192 -0
- parrot/handlers/crew/ARCHITECTURE.md +362 -0
- parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
- parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
- parrot/handlers/crew/__init__.py +0 -0
- parrot/handlers/crew/handler.py +801 -0
- parrot/handlers/crew/models.py +229 -0
- parrot/handlers/crew/redis_persistence.py +523 -0
- parrot/handlers/jobs/__init__.py +10 -0
- parrot/handlers/jobs/job.py +384 -0
- parrot/handlers/jobs/mixin.py +627 -0
- parrot/handlers/jobs/models.py +115 -0
- parrot/handlers/jobs/worker.py +31 -0
- parrot/handlers/models.py +596 -0
- parrot/handlers/o365_auth.py +105 -0
- parrot/handlers/stream.py +337 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/aws.py +143 -0
- parrot/interfaces/credentials.py +113 -0
- parrot/interfaces/database.py +27 -0
- parrot/interfaces/google.py +1123 -0
- parrot/interfaces/hierarchy.py +1227 -0
- parrot/interfaces/http.py +651 -0
- parrot/interfaces/images/__init__.py +0 -0
- parrot/interfaces/images/plugins/__init__.py +24 -0
- parrot/interfaces/images/plugins/abstract.py +58 -0
- parrot/interfaces/images/plugins/analisys.py +148 -0
- parrot/interfaces/images/plugins/classify.py +150 -0
- parrot/interfaces/images/plugins/classifybase.py +182 -0
- parrot/interfaces/images/plugins/detect.py +150 -0
- parrot/interfaces/images/plugins/exif.py +1103 -0
- parrot/interfaces/images/plugins/hash.py +52 -0
- parrot/interfaces/images/plugins/vision.py +104 -0
- parrot/interfaces/images/plugins/yolo.py +66 -0
- parrot/interfaces/images/plugins/zerodetect.py +197 -0
- parrot/interfaces/o365.py +978 -0
- parrot/interfaces/onedrive.py +822 -0
- parrot/interfaces/sharepoint.py +1435 -0
- parrot/interfaces/soap.py +257 -0
- parrot/loaders/__init__.py +8 -0
- parrot/loaders/abstract.py +1131 -0
- parrot/loaders/audio.py +199 -0
- parrot/loaders/basepdf.py +53 -0
- parrot/loaders/basevideo.py +1568 -0
- parrot/loaders/csv.py +409 -0
- parrot/loaders/docx.py +116 -0
- parrot/loaders/epubloader.py +316 -0
- parrot/loaders/excel.py +199 -0
- parrot/loaders/factory.py +55 -0
- parrot/loaders/files/__init__.py +0 -0
- parrot/loaders/files/abstract.py +39 -0
- parrot/loaders/files/html.py +26 -0
- parrot/loaders/files/text.py +63 -0
- parrot/loaders/html.py +152 -0
- parrot/loaders/markdown.py +442 -0
- parrot/loaders/pdf.py +373 -0
- parrot/loaders/pdfmark.py +320 -0
- parrot/loaders/pdftables.py +506 -0
- parrot/loaders/ppt.py +476 -0
- parrot/loaders/qa.py +63 -0
- parrot/loaders/splitters/__init__.py +10 -0
- parrot/loaders/splitters/base.py +138 -0
- parrot/loaders/splitters/md.py +228 -0
- parrot/loaders/splitters/token.py +143 -0
- parrot/loaders/txt.py +26 -0
- parrot/loaders/video.py +89 -0
- parrot/loaders/videolocal.py +218 -0
- parrot/loaders/videounderstanding.py +377 -0
- parrot/loaders/vimeo.py +167 -0
- parrot/loaders/web.py +599 -0
- parrot/loaders/youtube.py +504 -0
- parrot/manager/__init__.py +5 -0
- parrot/manager/manager.py +1030 -0
- parrot/mcp/__init__.py +28 -0
- parrot/mcp/adapter.py +105 -0
- parrot/mcp/cli.py +174 -0
- parrot/mcp/client.py +119 -0
- parrot/mcp/config.py +75 -0
- parrot/mcp/integration.py +842 -0
- parrot/mcp/oauth.py +933 -0
- parrot/mcp/server.py +225 -0
- parrot/mcp/transports/__init__.py +3 -0
- parrot/mcp/transports/base.py +279 -0
- parrot/mcp/transports/grpc_session.py +163 -0
- parrot/mcp/transports/http.py +312 -0
- parrot/mcp/transports/mcp.proto +108 -0
- parrot/mcp/transports/quic.py +1082 -0
- parrot/mcp/transports/sse.py +330 -0
- parrot/mcp/transports/stdio.py +309 -0
- parrot/mcp/transports/unix.py +395 -0
- parrot/mcp/transports/websocket.py +547 -0
- parrot/memory/__init__.py +16 -0
- parrot/memory/abstract.py +209 -0
- parrot/memory/agent.py +32 -0
- parrot/memory/cache.py +175 -0
- parrot/memory/core.py +555 -0
- parrot/memory/file.py +153 -0
- parrot/memory/mem.py +131 -0
- parrot/memory/redis.py +613 -0
- parrot/models/__init__.py +46 -0
- parrot/models/basic.py +118 -0
- parrot/models/compliance.py +208 -0
- parrot/models/crew.py +395 -0
- parrot/models/detections.py +654 -0
- parrot/models/generation.py +85 -0
- parrot/models/google.py +223 -0
- parrot/models/groq.py +23 -0
- parrot/models/openai.py +30 -0
- parrot/models/outputs.py +285 -0
- parrot/models/responses.py +938 -0
- parrot/notifications/__init__.py +743 -0
- parrot/openapi/__init__.py +3 -0
- parrot/openapi/components.yaml +641 -0
- parrot/openapi/config.py +322 -0
- parrot/outputs/__init__.py +32 -0
- parrot/outputs/formats/__init__.py +108 -0
- parrot/outputs/formats/altair.py +359 -0
- parrot/outputs/formats/application.py +122 -0
- parrot/outputs/formats/base.py +351 -0
- parrot/outputs/formats/bokeh.py +356 -0
- parrot/outputs/formats/card.py +424 -0
- parrot/outputs/formats/chart.py +436 -0
- parrot/outputs/formats/d3.py +255 -0
- parrot/outputs/formats/echarts.py +310 -0
- parrot/outputs/formats/generators/__init__.py +0 -0
- parrot/outputs/formats/generators/abstract.py +61 -0
- parrot/outputs/formats/generators/panel.py +145 -0
- parrot/outputs/formats/generators/streamlit.py +86 -0
- parrot/outputs/formats/generators/terminal.py +63 -0
- parrot/outputs/formats/holoviews.py +310 -0
- parrot/outputs/formats/html.py +147 -0
- parrot/outputs/formats/jinja2.py +46 -0
- parrot/outputs/formats/json.py +87 -0
- parrot/outputs/formats/map.py +933 -0
- parrot/outputs/formats/markdown.py +172 -0
- parrot/outputs/formats/matplotlib.py +237 -0
- parrot/outputs/formats/mixins/__init__.py +0 -0
- parrot/outputs/formats/mixins/emaps.py +855 -0
- parrot/outputs/formats/plotly.py +341 -0
- parrot/outputs/formats/seaborn.py +310 -0
- parrot/outputs/formats/table.py +397 -0
- parrot/outputs/formats/template_report.py +138 -0
- parrot/outputs/formats/yaml.py +125 -0
- parrot/outputs/formatter.py +152 -0
- parrot/outputs/templates/__init__.py +95 -0
- parrot/pipelines/__init__.py +0 -0
- parrot/pipelines/abstract.py +210 -0
- parrot/pipelines/detector.py +124 -0
- parrot/pipelines/models.py +90 -0
- parrot/pipelines/planogram.py +3002 -0
- parrot/pipelines/table.sql +97 -0
- parrot/plugins/__init__.py +106 -0
- parrot/plugins/importer.py +80 -0
- parrot/py.typed +0 -0
- parrot/registry/__init__.py +18 -0
- parrot/registry/registry.py +594 -0
- parrot/scheduler/__init__.py +1189 -0
- parrot/scheduler/models.py +60 -0
- parrot/security/__init__.py +16 -0
- parrot/security/prompt_injection.py +268 -0
- parrot/security/security_events.sql +25 -0
- parrot/services/__init__.py +1 -0
- parrot/services/mcp/__init__.py +8 -0
- parrot/services/mcp/config.py +13 -0
- parrot/services/mcp/server.py +295 -0
- parrot/services/o365_remote_auth.py +235 -0
- parrot/stores/__init__.py +7 -0
- parrot/stores/abstract.py +352 -0
- parrot/stores/arango.py +1090 -0
- parrot/stores/bigquery.py +1377 -0
- parrot/stores/cache.py +106 -0
- parrot/stores/empty.py +10 -0
- parrot/stores/faiss_store.py +1157 -0
- parrot/stores/kb/__init__.py +9 -0
- parrot/stores/kb/abstract.py +68 -0
- parrot/stores/kb/cache.py +165 -0
- parrot/stores/kb/doc.py +325 -0
- parrot/stores/kb/hierarchy.py +346 -0
- parrot/stores/kb/local.py +457 -0
- parrot/stores/kb/prompt.py +28 -0
- parrot/stores/kb/redis.py +659 -0
- parrot/stores/kb/store.py +115 -0
- parrot/stores/kb/user.py +374 -0
- parrot/stores/models.py +59 -0
- parrot/stores/pgvector.py +3 -0
- parrot/stores/postgres.py +2853 -0
- parrot/stores/utils/__init__.py +0 -0
- parrot/stores/utils/chunking.py +197 -0
- parrot/telemetry/__init__.py +3 -0
- parrot/telemetry/mixin.py +111 -0
- parrot/template/__init__.py +3 -0
- parrot/template/engine.py +259 -0
- parrot/tools/__init__.py +23 -0
- parrot/tools/abstract.py +644 -0
- parrot/tools/agent.py +363 -0
- parrot/tools/arangodbsearch.py +537 -0
- parrot/tools/arxiv_tool.py +188 -0
- parrot/tools/calculator/__init__.py +3 -0
- parrot/tools/calculator/operations/__init__.py +38 -0
- parrot/tools/calculator/operations/calculus.py +80 -0
- parrot/tools/calculator/operations/statistics.py +76 -0
- parrot/tools/calculator/tool.py +150 -0
- parrot/tools/cloudwatch.py +988 -0
- parrot/tools/codeinterpreter/__init__.py +127 -0
- parrot/tools/codeinterpreter/executor.py +371 -0
- parrot/tools/codeinterpreter/internals.py +473 -0
- parrot/tools/codeinterpreter/models.py +643 -0
- parrot/tools/codeinterpreter/prompts.py +224 -0
- parrot/tools/codeinterpreter/tool.py +664 -0
- parrot/tools/company_info/__init__.py +6 -0
- parrot/tools/company_info/tool.py +1138 -0
- parrot/tools/correlationanalysis.py +437 -0
- parrot/tools/database/abstract.py +286 -0
- parrot/tools/database/bq.py +115 -0
- parrot/tools/database/cache.py +284 -0
- parrot/tools/database/models.py +95 -0
- parrot/tools/database/pg.py +343 -0
- parrot/tools/databasequery.py +1159 -0
- parrot/tools/db.py +1800 -0
- parrot/tools/ddgo.py +370 -0
- parrot/tools/decorators.py +271 -0
- parrot/tools/dftohtml.py +282 -0
- parrot/tools/document.py +549 -0
- parrot/tools/ecs.py +819 -0
- parrot/tools/edareport.py +368 -0
- parrot/tools/elasticsearch.py +1049 -0
- parrot/tools/employees.py +462 -0
- parrot/tools/epson/__init__.py +96 -0
- parrot/tools/excel.py +683 -0
- parrot/tools/file/__init__.py +13 -0
- parrot/tools/file/abstract.py +76 -0
- parrot/tools/file/gcs.py +378 -0
- parrot/tools/file/local.py +284 -0
- parrot/tools/file/s3.py +511 -0
- parrot/tools/file/tmp.py +309 -0
- parrot/tools/file/tool.py +501 -0
- parrot/tools/file_reader.py +129 -0
- parrot/tools/flowtask/__init__.py +19 -0
- parrot/tools/flowtask/tool.py +761 -0
- parrot/tools/gittoolkit.py +508 -0
- parrot/tools/google/__init__.py +18 -0
- parrot/tools/google/base.py +169 -0
- parrot/tools/google/tools.py +1251 -0
- parrot/tools/googlelocation.py +5 -0
- parrot/tools/googleroutes.py +5 -0
- parrot/tools/googlesearch.py +5 -0
- parrot/tools/googlesitesearch.py +5 -0
- parrot/tools/googlevoice.py +2 -0
- parrot/tools/gvoice.py +695 -0
- parrot/tools/ibisworld/README.md +225 -0
- parrot/tools/ibisworld/__init__.py +11 -0
- parrot/tools/ibisworld/tool.py +366 -0
- parrot/tools/jiratoolkit.py +1718 -0
- parrot/tools/manager.py +1098 -0
- parrot/tools/math.py +152 -0
- parrot/tools/metadata.py +476 -0
- parrot/tools/msteams.py +1621 -0
- parrot/tools/msword.py +635 -0
- parrot/tools/multidb.py +580 -0
- parrot/tools/multistoresearch.py +369 -0
- parrot/tools/networkninja.py +167 -0
- parrot/tools/nextstop/__init__.py +4 -0
- parrot/tools/nextstop/base.py +286 -0
- parrot/tools/nextstop/employee.py +733 -0
- parrot/tools/nextstop/store.py +462 -0
- parrot/tools/notification.py +435 -0
- parrot/tools/o365/__init__.py +42 -0
- parrot/tools/o365/base.py +295 -0
- parrot/tools/o365/bundle.py +522 -0
- parrot/tools/o365/events.py +554 -0
- parrot/tools/o365/mail.py +992 -0
- parrot/tools/o365/onedrive.py +497 -0
- parrot/tools/o365/sharepoint.py +641 -0
- parrot/tools/openapi_toolkit.py +904 -0
- parrot/tools/openweather.py +527 -0
- parrot/tools/pdfprint.py +1001 -0
- parrot/tools/powerbi.py +518 -0
- parrot/tools/powerpoint.py +1113 -0
- parrot/tools/pricestool.py +146 -0
- parrot/tools/products/__init__.py +246 -0
- parrot/tools/prophet_tool.py +171 -0
- parrot/tools/pythonpandas.py +630 -0
- parrot/tools/pythonrepl.py +910 -0
- parrot/tools/qsource.py +436 -0
- parrot/tools/querytoolkit.py +395 -0
- parrot/tools/quickeda.py +827 -0
- parrot/tools/resttool.py +553 -0
- parrot/tools/retail/__init__.py +0 -0
- parrot/tools/retail/bby.py +528 -0
- parrot/tools/sandboxtool.py +703 -0
- parrot/tools/sassie/__init__.py +352 -0
- parrot/tools/scraping/__init__.py +7 -0
- parrot/tools/scraping/docs/select.md +466 -0
- parrot/tools/scraping/documentation.md +1278 -0
- parrot/tools/scraping/driver.py +436 -0
- parrot/tools/scraping/models.py +576 -0
- parrot/tools/scraping/options.py +85 -0
- parrot/tools/scraping/orchestrator.py +517 -0
- parrot/tools/scraping/readme.md +740 -0
- parrot/tools/scraping/tool.py +3115 -0
- parrot/tools/seasonaldetection.py +642 -0
- parrot/tools/shell_tool/__init__.py +5 -0
- parrot/tools/shell_tool/actions.py +408 -0
- parrot/tools/shell_tool/engine.py +155 -0
- parrot/tools/shell_tool/models.py +322 -0
- parrot/tools/shell_tool/tool.py +442 -0
- parrot/tools/site_search.py +214 -0
- parrot/tools/textfile.py +418 -0
- parrot/tools/think.py +378 -0
- parrot/tools/toolkit.py +298 -0
- parrot/tools/webapp_tool.py +187 -0
- parrot/tools/whatif.py +1279 -0
- parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
- parrot/tools/workday/__init__.py +6 -0
- parrot/tools/workday/models.py +1389 -0
- parrot/tools/workday/tool.py +1293 -0
- parrot/tools/yfinance_tool.py +306 -0
- parrot/tools/zipcode.py +217 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/helpers.py +73 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.c +12078 -0
- parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/parsers/toml.pyx +21 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpp +20936 -0
- parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/types.pyx +213 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- parrot/yaml-rs/Cargo.lock +350 -0
- parrot/yaml-rs/Cargo.toml +19 -0
- parrot/yaml-rs/pyproject.toml +19 -0
- parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
- parrot/yaml-rs/src/lib.rs +222 -0
- requirements/docker-compose.yml +24 -0
- requirements/requirements-dev.txt +21 -0
parrot/stores/postgres.py (new file; the hunk below shows the first 943 of its 2853 added lines):

@@ -0,0 +1,2853 @@
from typing import Any, Dict, List, Optional, Union, Callable
|
|
2
|
+
import uuid
|
|
3
|
+
from contextlib import asynccontextmanager
|
|
4
|
+
import numpy as np
|
|
5
|
+
import sqlalchemy
|
|
6
|
+
from sqlalchemy import (
|
|
7
|
+
text,
|
|
8
|
+
Column,
|
|
9
|
+
insert,
|
|
10
|
+
Table,
|
|
11
|
+
MetaData,
|
|
12
|
+
select,
|
|
13
|
+
asc,
|
|
14
|
+
func,
|
|
15
|
+
event,
|
|
16
|
+
JSON,
|
|
17
|
+
Index
|
|
18
|
+
)
|
|
19
|
+
from sqlalchemy.sql import literal_column
|
|
20
|
+
from sqlalchemy import bindparam
|
|
21
|
+
from sqlalchemy.orm import aliased
|
|
22
|
+
from sqlalchemy.ext.asyncio import (
|
|
23
|
+
create_async_engine,
|
|
24
|
+
AsyncSession,
|
|
25
|
+
AsyncEngine,
|
|
26
|
+
async_sessionmaker
|
|
27
|
+
)
|
|
28
|
+
from sqlalchemy.sql.expression import cast
|
|
29
|
+
from sqlalchemy.dialects.postgresql import JSONB, ARRAY
|
|
30
|
+
from sqlalchemy.orm import (
|
|
31
|
+
declarative_base,
|
|
32
|
+
DeclarativeBase,
|
|
33
|
+
Mapped,
|
|
34
|
+
mapped_column
|
|
35
|
+
)
|
|
36
|
+
# PgVector
|
|
37
|
+
from pgvector.sqlalchemy import Vector
|
|
38
|
+
from pgvector.asyncpg import register_vector
|
|
39
|
+
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
|
|
40
|
+
# Datamodel
|
|
41
|
+
from datamodel.parsers.json import json_encoder # pylint: disable=E0611
|
|
42
|
+
from navconfig.logging import logging
|
|
43
|
+
from .abstract import AbstractStore
|
|
44
|
+
from ..conf import default_sqlalchemy_pg
|
|
45
|
+
from .models import SearchResult, Document, DistanceStrategy
|
|
46
|
+
from .utils.chunking import LateChunkingProcessor
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def vector_distance(embedding_column, vector, op):
|
|
50
|
+
return text(f"{embedding_column} {op} :query_embedding").label("distance")
|
|
51
|
+
|
|
52
|
+
class Base(DeclarativeBase):
|
|
53
|
+
pass
|
|
54
|
+
|
|
55
|
+
class PgVectorStore(AbstractStore):
|
|
56
|
+
"""
|
|
57
|
+
A PostgreSQL vector store implementation using pgvector, completely independent of Langchain.
|
|
58
|
+
This store interacts directly with a specified schema and table for robust data isolation.
|
|
59
|
+
"""
|
|
60
|
+
def __init__(
|
|
61
|
+
self,
|
|
62
|
+
table: str = None,
|
|
63
|
+
schema: str = 'public',
|
|
64
|
+
id_column: str = 'id',
|
|
65
|
+
embedding_column: str = 'embedding',
|
|
66
|
+
document_column: str = 'document',
|
|
67
|
+
text_column: str = 'text',
|
|
68
|
+
embedding_model: Union[dict, str] = "sentence-transformers/all-mpnet-base-v2",
|
|
69
|
+
embedding: Optional[Callable] = None,
|
|
70
|
+
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
|
|
71
|
+
use_uuid: bool = False,
|
|
72
|
+
pool_size: int = 50,
|
|
73
|
+
auto_initialize: bool = True,
|
|
74
|
+
**kwargs
|
|
75
|
+
):
|
|
76
|
+
""" Initializes the PgVectorStore with the specified parameters.
|
|
77
|
+
"""
|
|
78
|
+
self.table_name = table
|
|
79
|
+
self.schema = schema
|
|
80
|
+
self._id_column: str = id_column
|
|
81
|
+
self._embedding_column: str = embedding_column
|
|
82
|
+
self._document_column: str = document_column
|
|
83
|
+
self._text_column: str = text_column
|
|
84
|
+
self.distance_strategy = distance_strategy
|
|
85
|
+
self._use_uuid: bool = use_uuid
|
|
86
|
+
self._embedding_store_cache: Dict[str, Any] = {}
|
|
87
|
+
self._max_size = pool_size or 50
|
|
88
|
+
self._auto_initialize_db: bool = auto_initialize
|
|
89
|
+
super().__init__(
|
|
90
|
+
embedding_model=embedding_model,
|
|
91
|
+
embedding=embedding,
|
|
92
|
+
**kwargs
|
|
93
|
+
)
|
|
94
|
+
self.dsn = kwargs.get('dsn', default_sqlalchemy_pg)
|
|
95
|
+
self._connection: Optional[AsyncEngine] = None
|
|
96
|
+
self._session_factory: Optional[async_sessionmaker] = None
|
|
97
|
+
self._session: Optional[AsyncSession] = None
|
|
98
|
+
self.logger = logging.getLogger("PgVectorStore")
|
|
99
|
+
self.embedding_store = None
|
|
100
|
+
if table:
|
|
101
|
+
# create a table definition:
|
|
102
|
+
self.embedding_store = self._define_collection_store(
|
|
103
|
+
table=table,
|
|
104
|
+
schema=schema,
|
|
105
|
+
dimension=self.dimension,
|
|
106
|
+
id_column=id_column,
|
|
107
|
+
embedding_column=embedding_column,
|
|
108
|
+
document_column=self._document_column,
|
|
109
|
+
text_column=text_column,
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
def get_id_column(self, use_uuid: bool) -> sqlalchemy.Column:
|
|
113
|
+
"""
|
|
114
|
+
Return the ID column definition based on whether to use UUID or not.
|
|
115
|
+
If use_uuid is True, the ID column will be a PostgreSQL UUID type with
|
|
116
|
+
server-side generation using uuid_generate_v4().
|
|
117
|
+
If use_uuid is False, the ID column will be a String type with a default
|
|
118
|
+
value generated by Python's uuid.uuid4() function.
|
|
119
|
+
"""
|
|
120
|
+
if use_uuid:
|
|
121
|
+
# DB will auto-generate UUID; SQLAlchemy should not set a default!
|
|
122
|
+
return sqlalchemy.Column(
|
|
123
|
+
sqlalchemy.dialects.postgresql.UUID(as_uuid=True),
|
|
124
|
+
primary_key=True,
|
|
125
|
+
index=True,
|
|
126
|
+
unique=True,
|
|
127
|
+
server_default=sqlalchemy.text('uuid_generate_v4()')
|
|
128
|
+
)
|
|
129
|
+
else:
|
|
130
|
+
# Python generates UUID (as string)
|
|
131
|
+
return sqlalchemy.Column(
|
|
132
|
+
sqlalchemy.String,
|
|
133
|
+
primary_key=True,
|
|
134
|
+
index=True,
|
|
135
|
+
unique=True,
|
|
136
|
+
default=lambda: str(uuid.uuid4())
|
|
137
|
+
)
|
|
138
|
+
|
|
139
|
+
def _define_collection_store(
|
|
140
|
+
self,
|
|
141
|
+
table: str,
|
|
142
|
+
schema: str,
|
|
143
|
+
dimension: int = 384,
|
|
144
|
+
id_column: str = 'id',
|
|
145
|
+
embedding_column: str = 'embedding',
|
|
146
|
+
document_column: str = 'document',
|
|
147
|
+
metadata_column: str = 'cmetadata',
|
|
148
|
+
text_column: str = 'text',
|
|
149
|
+
store_name: str = 'EmbeddingStore',
|
|
150
|
+
colbert_dimension: int = 128 # ColBERT token dimension
|
|
151
|
+
) -> Any:
|
|
152
|
+
"""Dynamically define a SQLAlchemy Table for pgvector storage.
|
|
153
|
+
|
|
154
|
+
Args:
|
|
155
|
+
table: The name of the table to create.
|
|
156
|
+
schema: The schema in which to create the table.
|
|
157
|
+
dimension: The dimensionality of the vector embeddings.
|
|
158
|
+
"""
|
|
159
|
+
fq_table_name = f"{schema}.{table}"
|
|
160
|
+
if fq_table_name in self._embedding_store_cache:
|
|
161
|
+
return self._embedding_store_cache[fq_table_name]
|
|
162
|
+
|
|
163
|
+
self.logger.notice(
|
|
164
|
+
f"Defining dynamic ORM class for table {fq_table_name} with dimension {dimension}"
|
|
165
|
+
)
|
|
166
|
+
table_args = {
|
|
167
|
+
"schema": schema,
|
|
168
|
+
"extend_existing": True
|
|
169
|
+
}
|
|
170
|
+
attrs = {
|
|
171
|
+
'__tablename__': table,
|
|
172
|
+
'__table_args__': table_args,
|
|
173
|
+
# id_column: self.get_id_column(use_uuid=self._use_uuid),
|
|
174
|
+
id_column: mapped_column(
|
|
175
|
+
sqlalchemy.String if not self._use_uuid else sqlalchemy.dialects.postgresql.UUID(as_uuid=True),
|
|
176
|
+
primary_key=True,
|
|
177
|
+
index=True,
|
|
178
|
+
unique=True,
|
|
179
|
+
default=lambda: str(uuid.uuid4()) if not self._use_uuid else None,
|
|
180
|
+
server_default=sqlalchemy.text('uuid_generate_v4()') if self._use_uuid else None
|
|
181
|
+
),
|
|
182
|
+
embedding_column: mapped_column(Vector(dimension)),
|
|
183
|
+
text_column: mapped_column(sqlalchemy.String, nullable=True),
|
|
184
|
+
document_column: mapped_column(sqlalchemy.String, nullable=True),
|
|
185
|
+
metadata_column: mapped_column(JSONB, nullable=True),
|
|
186
|
+
|
|
187
|
+
# embedding_column: Column(Vector(dimension)),
|
|
188
|
+
# text_column: Column(sqlalchemy.String, nullable=True),
|
|
189
|
+
# document_column: Column(sqlalchemy.String, nullable=True),
|
|
190
|
+
# metadata_column: Column(JSONB, nullable=True),
|
|
191
|
+
# ColBERT columns
|
|
192
|
+
# 'token_embeddings': Column(ARRAY(Vector(colbert_dimension)), nullable=True),
|
|
193
|
+
# 'num_tokens': Column(sqlalchemy.Integer, nullable=True),
|
|
194
|
+
# 'collection_id': Column(
|
|
195
|
+
# sqlalchemy.dialects.postgresql.UUID(as_uuid=True),
|
|
196
|
+
# index=True,
|
|
197
|
+
# unique=True,
|
|
198
|
+
# default=uuid.uuid4,
|
|
199
|
+
# server_default=sqlalchemy.text('uuid_generate_v4()')
|
|
200
|
+
# )
|
|
201
|
+
'token_embeddings': mapped_column(ARRAY(Vector(colbert_dimension)), nullable=True),
|
|
202
|
+
'num_tokens': mapped_column(sqlalchemy.Integer, nullable=True),
|
|
203
|
+
'collection_id': mapped_column(
|
|
204
|
+
sqlalchemy.dialects.postgresql.UUID(as_uuid=True),
|
|
205
|
+
index=True,
|
|
206
|
+
unique=True,
|
|
207
|
+
default=uuid.uuid4,
|
|
208
|
+
server_default=sqlalchemy.text('uuid_generate_v4()')
|
|
209
|
+
)
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
# Create dynamic ORM class
|
|
213
|
+
EmbeddingStore = type(store_name, (Base,), attrs)
|
|
214
|
+
EmbeddingStore.__name__ = store_name
|
|
215
|
+
EmbeddingStore.__qualname__ = store_name
|
|
216
|
+
|
|
217
|
+
# Cache the store
|
|
218
|
+
self._embedding_store_cache[fq_table_name] = EmbeddingStore
|
|
219
|
+
self.logger.debug(
|
|
220
|
+
f"Created dynamic ORM class {store_name} for table {fq_table_name}"
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
return EmbeddingStore
|
|
224
|
+
|
|
225
|
+
def define_collection_table(
|
|
226
|
+
self,
|
|
227
|
+
table: str,
|
|
228
|
+
schema: str,
|
|
229
|
+
dimension: int = 384,
|
|
230
|
+
metadata: Optional[MetaData] = None,
|
|
231
|
+
use_uuid: bool = False,
|
|
232
|
+
id_column: str = 'id',
|
|
233
|
+
embedding_column: str = 'embedding'
|
|
234
|
+
) -> sqlalchemy.Table:
|
|
235
|
+
"""Dynamically define a SQLAlchemy Table for pgvector storage."""
|
|
236
|
+
columns = []
|
|
237
|
+
|
|
238
|
+
if use_uuid:
|
|
239
|
+
columns.append(Column(
|
|
240
|
+
id_column,
|
|
241
|
+
sqlalchemy.dialects.postgresql.UUID(as_uuid=True),
|
|
242
|
+
primary_key=True,
|
|
243
|
+
server_default=sqlalchemy.text("uuid_generate_v4()")
|
|
244
|
+
))
|
|
245
|
+
else:
|
|
246
|
+
columns.append(Column(
|
|
247
|
+
id_column,
|
|
248
|
+
sqlalchemy.String,
|
|
249
|
+
primary_key=True,
|
|
250
|
+
default=lambda: str(uuid.uuid4())
|
|
251
|
+
))
|
|
252
|
+
|
|
253
|
+
columns.extend([
|
|
254
|
+
Column(embedding_column, Vector(dimension)),
|
|
255
|
+
Column('text', sqlalchemy.String, nullable=True),
|
|
256
|
+
Column('document', sqlalchemy.String, nullable=True),
|
|
257
|
+
Column('cmetadata', JSONB, nullable=True)
|
|
258
|
+
])
|
|
259
|
+
|
|
260
|
+
return Table(
|
|
261
|
+
table,
|
|
262
|
+
metadata,
|
|
263
|
+
*columns,
|
|
264
|
+
schema=schema
|
|
265
|
+
)
|
|
266
|
+
|
|
267
|
+
async def connection(self, dsn: str = None) -> AsyncEngine:
|
|
268
|
+
"""Establishes and returns an async database connection."""
|
|
269
|
+
if self._connection is not None:
|
|
270
|
+
return self._connection
|
|
271
|
+
if not dsn:
|
|
272
|
+
dsn = self.dsn or default_sqlalchemy_pg
|
|
273
|
+
try:
|
|
274
|
+
self._connection = create_async_engine(
|
|
275
|
+
dsn,
|
|
276
|
+
future=True,
|
|
277
|
+
pool_size=self._max_size, # High concurrency support
|
|
278
|
+
max_overflow=100, # Burst capacity
|
|
279
|
+
pool_pre_ping=True, # Connection health checks
|
|
280
|
+
pool_recycle=3600, # Prevent stale connections (1 hour)
|
|
281
|
+
pool_timeout=30, # Wait max 30s for connection
|
|
282
|
+
connect_args={
|
|
283
|
+
"server_settings": {
|
|
284
|
+
"jit": "off", # Disable JIT for vector queries
|
|
285
|
+
"random_page_cost": "1.1", # SSD optimization
|
|
286
|
+
"effective_cache_size": "24GB", # Memory configuration
|
|
287
|
+
"work_mem": "256MB"
|
|
288
|
+
}
|
|
289
|
+
}
|
|
290
|
+
)
|
|
291
|
+
# @event.listens_for(self._connection.sync_engine, "first_connect")
|
|
292
|
+
# def connect(dbapi_connection, connection_record):
|
|
293
|
+
# dbapi_connection.run_async(register_vector)
|
|
294
|
+
|
|
295
|
+
# Create session factory
|
|
296
|
+
self._session_factory = async_sessionmaker(
|
|
297
|
+
bind=self._connection,
|
|
298
|
+
class_=AsyncSession,
|
|
299
|
+
expire_on_commit=False,
|
|
300
|
+
autoflush=False, # Manual control over flushing
|
|
301
|
+
autocommit=False
|
|
302
|
+
)
|
|
303
|
+
if self._auto_initialize_db:
|
|
304
|
+
await self.initialize_database()
|
|
305
|
+
self._connected = True
|
|
306
|
+
self.logger.info(
|
|
307
|
+
"Successfully connected to PostgreSQL."
|
|
308
|
+
)
|
|
309
|
+
except Exception as e:
|
|
310
|
+
self.logger.error(
|
|
311
|
+
f"Failed to connect to PostgreSQL: {e}"
|
|
312
|
+
)
|
|
313
|
+
self._connected = False
|
|
314
|
+
raise
|
|
315
|
+
|
|
316
|
+
async def get_session(self) -> AsyncSession:
|
|
317
|
+
"""Get a session from the pool. This is the main method for getting connections."""
|
|
318
|
+
if not self._connection:
|
|
319
|
+
await self.connection()
|
|
320
|
+
|
|
321
|
+
if not self._session_factory:
|
|
322
|
+
raise RuntimeError("Session factory not initialized")
|
|
323
|
+
|
|
324
|
+
return self._session_factory()
|
|
325
|
+
|
|
326
|
+
@asynccontextmanager
|
|
327
|
+
async def session(self):
|
|
328
|
+
"""
|
|
329
|
+
Context manager for handling database sessions with proper cleanup.
|
|
330
|
+
This is the recommended way to handle database operations.
|
|
331
|
+
|
|
332
|
+
Usage:
|
|
333
|
+
async with store.session() as session:
|
|
334
|
+
result = await session.execute(stmt)
|
|
335
|
+
await session.commit() # if needed
|
|
336
|
+
"""
|
|
337
|
+
if not self._connection:
|
|
338
|
+
await self.connection()
|
|
339
|
+
|
|
340
|
+
session = await self.get_session()
|
|
341
|
+
try:
|
|
342
|
+
yield session
|
|
343
|
+
# Auto-commit if no exception occurred
|
|
344
|
+
if session.in_transaction():
|
|
345
|
+
await session.commit()
|
|
346
|
+
except Exception:
|
|
347
|
+
# Auto-rollback on exception
|
|
348
|
+
if session.in_transaction():
|
|
349
|
+
await session.rollback()
|
|
350
|
+
raise
|
|
351
|
+
finally:
|
|
352
|
+
# Always close session (returns connection to pool)
|
|
353
|
+
await session.close()
|
|
354
|
+
|
|
355
|
+
async def initialize_database(self):
|
|
356
|
+
"""Initialize with PgVector 0.8.0+ optimizations"""
|
|
357
|
+
try:
|
|
358
|
+
async with self.session() as session:
|
|
359
|
+
await session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
|
360
|
+
|
|
361
|
+
# Enable iterative scanning (breakthrough feature)
|
|
362
|
+
await session.execute(text("SET hnsw.iterative_scan = 'relaxed_order'"))
|
|
363
|
+
await session.execute(text("SET hnsw.max_scan_tuples = 20000"))
|
|
364
|
+
await session.execute(text("SET hnsw.ef_search = 200"))
|
|
365
|
+
await session.execute(text("SET ivfflat.iterative_scan = 'on'"))
|
|
366
|
+
await session.execute(text("SET ivfflat.max_probes = 100"))
|
|
367
|
+
|
|
368
|
+
# Performance tuning
|
|
369
|
+
await session.execute(text("SET maintenance_work_mem = '2GB'"))
|
|
370
|
+
await session.execute(text("SET max_parallel_maintenance_workers = 8"))
|
|
371
|
+
await session.execute(text("SET enable_seqscan = off"))
|
|
372
|
+
|
|
373
|
+
# Create ColBERT MaxSim function
|
|
374
|
+
await self._create_maxsim_function(session)
|
|
375
|
+
|
|
376
|
+
await session.commit()
|
|
377
|
+
except Exception as e:
|
|
378
|
+
self.logger.warning(f"⚠️ Database auto-initialization failed: {e}")
|
|
379
|
+
# Don't raise - let the engine continue to work
|
|
380
|
+
|
|
381
|
+
async def _create_maxsim_function(self, session):
|
|
382
|
+
"""Create the MaxSim function for ColBERT late interaction scoring."""
|
|
383
|
+
maxsim_function = text("""
|
|
384
|
+
CREATE OR REPLACE FUNCTION max_sim(document vector[], query vector[])
|
|
385
|
+
RETURNS double precision AS $$
|
|
386
|
+
DECLARE
|
|
387
|
+
query_vec vector;
|
|
388
|
+
doc_vec vector;
|
|
389
|
+
max_similarity double precision;
|
|
390
|
+
total_score double precision := 0;
|
|
391
|
+
similarity double precision;
|
|
392
|
+
BEGIN
|
|
393
|
+
-- For each query token, find the maximum similarity with any document token
|
|
394
|
+
FOR i IN 1..array_length(query, 1) LOOP
|
|
395
|
+
query_vec := query[i];
|
|
396
|
+
max_similarity := -1;
|
|
397
|
+
|
|
398
|
+
-- Find max similarity with all document tokens
|
|
399
|
+
FOR j IN 1..array_length(document, 1) LOOP
|
|
400
|
+
doc_vec := document[j];
|
|
401
|
+
similarity := 1 - (query_vec <=> doc_vec); -- Convert distance to similarity
|
|
402
|
+
|
|
403
|
+
IF similarity > max_similarity THEN
|
|
404
|
+
max_similarity := similarity;
|
|
405
|
+
END IF;
|
|
406
|
+
END LOOP;
|
|
407
|
+
|
|
408
|
+
-- Add the maximum similarity for this query token
|
|
409
|
+
total_score := total_score + max_similarity;
|
|
410
|
+
END LOOP;
|
|
411
|
+
|
|
412
|
+
RETURN total_score;
|
|
413
|
+
END;
|
|
414
|
+
$$ LANGUAGE plpgsql IMMUTABLE STRICT;
|
|
415
|
+
""")
|
|
416
|
+
|
|
417
|
+
await session.execute(maxsim_function)
|
|
418
|
+
self.logger.info("✅ Created ColBERT MaxSim function")
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
# Async Context Manager - improved pattern
|
|
422
|
+
async def __aenter__(self):
|
|
423
|
+
"""
|
|
424
|
+
Context manager entry. Ensures engine is initialized and manages session lifecycle.
|
|
425
|
+
"""
|
|
426
|
+
if not self._connection:
|
|
427
|
+
await self.connection()
|
|
428
|
+
|
|
429
|
+
# Create a session for this context if we don't have one
|
|
430
|
+
if self._session is None:
|
|
431
|
+
self._session = await self.get_session()
|
|
432
|
+
|
|
433
|
+
self._context_depth += 1
|
|
434
|
+
return self
|
|
435
|
+
|
|
436
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
437
|
+
"""
|
|
438
|
+
Context manager exit. Properly handles session cleanup.
|
|
439
|
+
"""
|
|
440
|
+
self._context_depth -= 1
|
|
441
|
+
|
|
442
|
+
# Only close session when we exit the outermost context
|
|
443
|
+
if self._context_depth == 0 and self._session:
|
|
444
|
+
try:
|
|
445
|
+
if exc_type is not None:
|
|
446
|
+
# Exception occurred, rollback
|
|
447
|
+
if self._session.in_transaction():
|
|
448
|
+
await self._session.rollback()
|
|
449
|
+
else:
|
|
450
|
+
# No exception, commit if in transaction
|
|
451
|
+
if self._session.in_transaction():
|
|
452
|
+
await self._session.commit()
|
|
453
|
+
finally:
|
|
454
|
+
# Always close the session (returns connection to pool)
|
|
455
|
+
await self._session.close()
|
|
456
|
+
self._session = None
|
|
457
|
+
|
|
458
|
+
async def _free_resources(self):
|
|
459
|
+
"""Clean up resources but keep the engine/pool available."""
|
|
460
|
+
if self._embed_:
|
|
461
|
+
self._embed_.free()
|
|
462
|
+
self._embed_ = None
|
|
463
|
+
|
|
464
|
+
# Close current session if exists
|
|
465
|
+
if self._session:
|
|
466
|
+
await self._session.close()
|
|
467
|
+
self._session = None
|
|
468
|
+
|
|
469
|
+
async def disconnect(self) -> None:
|
|
470
|
+
"""
|
|
471
|
+
Completely dispose of the engine and close all connections.
|
|
472
|
+
Call this when you're completely done with the store.
|
|
473
|
+
"""
|
|
474
|
+
# Close current session first
|
|
475
|
+
if self._session:
|
|
476
|
+
await self._session.close()
|
|
477
|
+
self._session = None
|
|
478
|
+
|
|
479
|
+
# Dispose of the engine (closes all pooled connections)
|
|
480
|
+
if self._connection:
|
|
481
|
+
await self._connection.dispose()
|
|
482
|
+
self._connection = None
|
|
483
|
+
self._connected = False
|
|
484
|
+
self._session_factory = None
|
|
485
|
+
self.logger.info(
|
|
486
|
+
"🔌 PostgreSQL engine disposed and all connections closed"
|
|
487
|
+
)
|
|
488
|
+
|
|
489
|
+
async def add_documents(
|
|
490
|
+
self,
|
|
491
|
+
documents: List[Document],
|
|
492
|
+
table: str = None,
|
|
493
|
+
schema: str = None,
|
|
494
|
+
embedding_column: str = 'embedding',
|
|
495
|
+
content_column: str = 'document',
|
|
496
|
+
metadata_column: str = 'cmetadata',
|
|
497
|
+
**kwargs
|
|
498
|
+
) -> None:
|
|
499
|
+
"""
|
|
500
|
+
Embeds and adds documents to the specified table.
|
|
501
|
+
|
|
502
|
+
Args:
|
|
503
|
+
documents: A list of Document objects to add.
|
|
504
|
+
table: The name of the table.
|
|
505
|
+
schema: The database schema where the table resides.
|
|
506
|
+
embedding_column: The name of the column to store embeddings.
|
|
507
|
+
content_column: The name of the column to store the main text content.
|
|
508
|
+
metadata_column: The name of the JSONB column for metadata.
|
|
509
|
+
"""
|
|
510
|
+
if not self._connected:
|
|
511
|
+
await self.connection()
|
|
512
|
+
|
|
513
|
+
if not table:
|
|
514
|
+
table = self.table_name
|
|
515
|
+
if not schema:
|
|
516
|
+
schema = self.schema
|
|
517
|
+
|
|
518
|
+
texts = [doc.page_content for doc in documents]
|
|
519
|
+
embeddings = self._embed_.embed_documents(texts)
|
|
520
|
+
metadatas = [doc.metadata for doc in documents]
|
|
521
|
+
|
|
522
|
+
# Step 1: Ensure the ORM table is initialized
|
|
523
|
+
if self.embedding_store is None:
|
|
524
|
+
self.embedding_store = self._define_collection_store(
|
|
525
|
+
table=table,
|
|
526
|
+
schema=schema,
|
|
527
|
+
dimension=self.dimension,
|
|
528
|
+
id_column=self._id_column,
|
|
529
|
+
embedding_column=embedding_column,
|
|
530
|
+
document_column=content_column,
|
|
531
|
+
metadata_column=metadata_column,
|
|
532
|
+
text_column=self._text_column,
|
|
533
|
+
)
|
|
534
|
+
|
|
535
|
+
# Step 2: Prepare values for bulk insert
|
|
536
|
+
values = [
|
|
537
|
+
{
|
|
538
|
+
self._id_column: str(uuid.uuid4()),
|
|
539
|
+
embedding_column: embeddings[i].tolist() if isinstance(
|
|
540
|
+
embeddings[i], np.ndarray
|
|
541
|
+
) else embeddings[i],
|
|
542
|
+
content_column: texts[i],
|
|
543
|
+
metadata_column: metadatas[i] or {}
|
|
544
|
+
}
|
|
545
|
+
for i in range(len(documents))
|
|
546
|
+
]
|
|
547
|
+
|
|
548
|
+
# Step 3: Build insert statement using SQLAlchemy's insert()
|
|
549
|
+
insert_stmt = insert(self.embedding_store)
|
|
550
|
+
|
|
551
|
+
# Step 4: Execute using async executemany
|
|
552
|
+
try:
|
|
553
|
+
async with self.session() as session:
|
|
554
|
+
await session.execute(insert_stmt, values)
|
|
555
|
+
self.logger.info(
|
|
556
|
+
f"✅ Successfully added {len(documents)} documents to '{schema}.{table}'"
|
|
557
|
+
)
|
|
558
|
+
except Exception as e:
|
|
559
|
+
self.logger.error(f"Error adding documents: {e}")
|
|
560
|
+
raise
|
|
561
|
+
|
|
562
|
+
def get_distance_strategy(
|
|
563
|
+
self,
|
|
564
|
+
embedding_column_obj,
|
|
565
|
+
query_embedding,
|
|
566
|
+
metric: str = None
|
|
567
|
+
) -> Any:
|
|
568
|
+
"""
|
|
569
|
+
Return the appropriate distance expression based on the metric or configured strategy.
|
|
570
|
+
|
|
571
|
+
Args:
|
|
572
|
+
embedding_column_obj: The SQLAlchemy column object for embeddings
|
|
573
|
+
query_embedding: The query embedding vector
|
|
574
|
+
metric: Optional metric string ('COSINE', 'L2', 'IP', 'DOT')
|
|
575
|
+
- if None, uses self.distance_strategy
|
|
576
|
+
"""
|
|
577
|
+
# Use provided metric or fall back to instance distance_strategy
|
|
578
|
+
strategy = metric or self.distance_strategy
|
|
579
|
+
# self.logger.debug(
|
|
580
|
+
# f"PgVector: using distance strategy → {strategy}"
|
|
581
|
+
# )
|
|
582
|
+
|
|
583
|
+
# Convert string metrics to DistanceStrategy enum if needed
|
|
584
|
+
if isinstance(strategy, str):
|
|
585
|
+
metric_mapping = {
|
|
586
|
+
'COSINE': DistanceStrategy.COSINE,
|
|
587
|
+
'L2': DistanceStrategy.EUCLIDEAN_DISTANCE,
|
|
588
|
+
'EUCLIDEAN': DistanceStrategy.EUCLIDEAN_DISTANCE,
|
|
589
|
+
'IP': DistanceStrategy.MAX_INNER_PRODUCT,
|
|
590
|
+
'DOT': DistanceStrategy.DOT_PRODUCT,
|
|
591
|
+
'DOT_PRODUCT': DistanceStrategy.DOT_PRODUCT,
|
|
592
|
+
'MAX_INNER_PRODUCT': DistanceStrategy.MAX_INNER_PRODUCT
|
|
593
|
+
}
|
|
594
|
+
strategy = metric_mapping.get(strategy.upper(), DistanceStrategy.COSINE)
|
|
595
|
+
|
|
596
|
+
# self.logger.debug(
|
|
597
|
+
# f"PgVector: using distance strategy → {strategy}"
|
|
598
|
+
# )
|
|
599
|
+
|
|
600
|
+
# Convert numpy array to list if needed
|
|
601
|
+
if isinstance(query_embedding, np.ndarray):
|
|
602
|
+
query_embedding = query_embedding.tolist()
|
|
603
|
+
|
|
604
|
+
if strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
|
|
605
|
+
return embedding_column_obj.l2_distance(query_embedding)
|
|
606
|
+
elif strategy == DistanceStrategy.COSINE:
|
|
607
|
+
return embedding_column_obj.cosine_distance(query_embedding)
|
|
608
|
+
elif strategy == DistanceStrategy.MAX_INNER_PRODUCT:
|
|
609
|
+
return embedding_column_obj.max_inner_product(query_embedding)
|
|
610
|
+
elif strategy == DistanceStrategy.DOT_PRODUCT:
|
|
611
|
+
# Note: pgvector doesn't have dot_product, using max_inner_product
|
|
612
|
+
return embedding_column_obj.max_inner_product(query_embedding)
|
|
613
|
+
else:
|
|
614
|
+
raise ValueError(
|
|
615
|
+
f"Got unexpected value for distance: {strategy}. "
|
|
616
|
+
f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}."
|
|
617
|
+
)
|
|
618
|
+
|
|
619
|
+
async def similarity_search(
|
|
620
|
+
self,
|
|
621
|
+
query: str,
|
|
622
|
+
table: str = None,
|
|
623
|
+
schema: str = None,
|
|
624
|
+
k: Optional[int] = None,
|
|
625
|
+
limit: int = None,
|
|
626
|
+
metadata_filters: Optional[Dict[str, Any]] = None,
|
|
627
|
+
score_threshold: Optional[float] = None,
|
|
628
|
+
metric: str = None,
|
|
629
|
+
embedding_column: str = 'embedding',
|
|
630
|
+
content_column: str = 'document',
|
|
631
|
+
metadata_column: str = 'cmetadata',
|
|
632
|
+
id_column: str = 'id',
|
|
633
|
+
additional_columns: Optional[List[str]] = None
|
|
634
|
+
) -> List[SearchResult]:
|
|
635
|
+
"""
|
|
636
|
+
Perform similarity search with optional threshold filtering.
|
|
637
|
+
|
|
638
|
+
Args:
|
|
639
|
+
query: The search query text
|
|
640
|
+
table: Table name (optional, uses default if not provided)
|
|
641
|
+
schema: Schema name (optional, uses default if not provided)
|
|
642
|
+
limit: Maximum number of results to return
|
|
643
|
+
score_threshold: Maximum distance threshold
|
|
644
|
+
results with distance > threshold will be filtered out)
|
|
645
|
+
metadata_filters: Dictionary of metadata filters to apply
|
|
646
|
+
metric: Distance metric to use ('COSINE', 'L2', 'IP')
|
|
647
|
+
embedding_column: Name of the embedding column
|
|
648
|
+
content_column: Name of the content column
|
|
649
|
+
metadata_column: Name of the metadata column
|
|
650
|
+
id_column: Name of the ID column
|
|
651
|
+
additional_columns: List of additional columns to include in results.
|
|
652
|
+
Returns:
|
|
653
|
+
List of SearchResult objects with content, metadata, score, collection_id, and record_id
|
|
654
|
+
"""
|
|
655
|
+
if not self._connected:
|
|
656
|
+
await self.connection()
|
|
657
|
+
|
|
658
|
+
table = table or self.table_name
|
|
659
|
+
schema = schema or self.schema
|
|
660
|
+
|
|
661
|
+
if k and not limit:
|
|
662
|
+
limit = k
|
|
663
|
+
if not limit:
|
|
664
|
+
limit = 10
|
|
665
|
+
|
|
666
|
+
# Step 1: Ensure the ORM class exists
|
|
667
|
+
if not self.embedding_store:
|
|
668
|
+
self.embedding_store = self._define_collection_store(
|
|
669
|
+
table=table,
|
|
670
|
+
schema=schema,
|
|
671
|
+
dimension=self.dimension,
|
|
672
|
+
id_column=self._id_column,
|
|
673
|
+
embedding_column=embedding_column,
|
|
674
|
+
document_column=content_column,
|
|
675
|
+
metadata_column=metadata_column,
|
|
676
|
+
text_column=self._text_column,
|
|
677
|
+
)
|
|
678
|
+
|
|
679
|
+
# Step 2: Embed the query
|
|
680
|
+
query_embedding = self._embed_.embed_query(query)
|
|
681
|
+
|
|
682
|
+
# Get the actual column objects
|
|
683
|
+
content_col = getattr(self.embedding_store, content_column)
|
|
684
|
+
metadata_col = getattr(self.embedding_store, metadata_column)
|
|
685
|
+
embedding_col = getattr(self.embedding_store, embedding_column)
|
|
686
|
+
id_col = getattr(self.embedding_store, id_column)
|
|
687
|
+
collection_id_col = getattr(self.embedding_store, 'collection_id')
|
|
688
|
+
|
|
689
|
+
# Get the distance expression using the appropriate method
|
|
690
|
+
distance_expr = self.get_distance_strategy(
|
|
691
|
+
embedding_col,
|
|
692
|
+
query_embedding,
|
|
693
|
+
metric=metric
|
|
694
|
+
).label("distance")
|
|
695
|
+
# self.logger.debug(f"Compiled distance expr → {distance_expr}")
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
# Build the select columns list
|
|
699
|
+
select_columns = [
|
|
700
|
+
id_col,
|
|
701
|
+
content_col,
|
|
702
|
+
metadata_col,
|
|
703
|
+
distance_expr,
|
|
704
|
+
collection_id_col,
|
|
705
|
+
]
|
|
706
|
+
|
|
707
|
+
# Add additional columns dynamically using literal_column (no validation)
|
|
708
|
+
if additional_columns:
|
|
709
|
+
for col_name in additional_columns:
|
|
710
|
+
# Use literal_column to reference any column name without ORM validation
|
|
711
|
+
additional_col = literal_column(f'"{col_name}"').label(col_name)
|
|
712
|
+
select_columns.append(additional_col)
|
|
713
|
+
self.logger.debug(f"Added dynamic column: {col_name}")
|
|
714
|
+
|
|
715
|
+
# Step 5: Construct statement
|
|
716
|
+
stmt = (
|
|
717
|
+
select(*select_columns)
|
|
718
|
+
.select_from(self.embedding_store) # Explicitly specify the table
|
|
719
|
+
.order_by(asc(distance_expr))
|
|
720
|
+
)
|
|
721
|
+
|
|
722
|
+
# Apply threshold filter if provided
|
|
723
|
+
if score_threshold is not None:
|
|
724
|
+
stmt = stmt.where(distance_expr <= score_threshold)
|
|
725
|
+
|
|
726
|
+
if limit:
|
|
727
|
+
stmt = stmt.limit(limit)
|
|
728
|
+
|
|
729
|
+
# 6) Apply any JSONB metadata filters
|
|
730
|
+
if metadata_filters:
|
|
731
|
+
for key, val in metadata_filters.items():
|
|
732
|
+
if isinstance(val, bool):
|
|
733
|
+
# Handle boolean values properly in JSONB
|
|
734
|
+
stmt = stmt.where(
|
|
735
|
+
metadata_col[key].astext.cast(sqlalchemy.Boolean) == val
|
|
736
|
+
)
|
|
737
|
+
else:
|
|
738
|
+
stmt = stmt.where(
|
|
739
|
+
metadata_col[key].astext == str(val)
|
|
740
|
+
)
|
|
741
|
+
|
|
742
|
+
try:
|
|
743
|
+
# Execute query
|
|
744
|
+
async with self.session() as session:
|
|
745
|
+
result = await session.execute(stmt)
|
|
746
|
+
rows = result.fetchall()
|
|
747
|
+
# Create enhanced SearchResult objects
|
|
748
|
+
results = []
|
|
749
|
+
for row in rows:
|
|
750
|
+
metadata = row[2]
|
|
751
|
+
metadata['collection_id'] = row[4]
|
|
752
|
+
# Add additional columns as a dictionary (starting from index 5)
|
|
753
|
+
if additional_columns:
|
|
754
|
+
for i, col_name in enumerate(additional_columns):
|
|
755
|
+
metadata[col_name] = row[5 + i]
|
|
756
|
+
# Create an enhanced SearchResult with additional fields
|
|
757
|
+
search_result = SearchResult(
|
|
758
|
+
id=row[0],
|
|
759
|
+
content=row[1], # content_col
|
|
760
|
+
metadata=metadata, # metadata_col
|
|
761
|
+
score=row[3] # distance
|
|
762
|
+
)
|
|
763
|
+
results.append(search_result)
|
|
764
|
+
|
|
765
|
+
return results
|
|
766
|
+
except Exception as e:
|
|
767
|
+
self.logger.error(f"Error during similarity search: {e}")
|
|
768
|
+
raise
|
|
769
|
+
|
|
770
|
+
def get_vector(self, metric_type: str = None, **kwargs):
|
|
771
|
+
raise NotImplementedError("This method is part of the old implementation.")
|
|
772
|
+
|
|
773
|
+
async def drop_collection(self, table: str, schema: str = 'public') -> None:
|
|
774
|
+
"""
|
|
775
|
+
Drops the specified table in the given schema.
|
|
776
|
+
|
|
777
|
+
Args:
|
|
778
|
+
table: The name of the table to drop.
|
|
779
|
+
schema: The database schema where the table resides.
|
|
780
|
+
"""
|
|
781
|
+
if not self._connected:
|
|
782
|
+
await self.connection()
|
|
783
|
+
|
|
784
|
+
full_table_name = f"{schema}.{table}"
|
|
785
|
+
async with self._connection.begin() as conn:
|
|
786
|
+
await conn.execute(text(f"DROP TABLE IF EXISTS {full_table_name}"))
|
|
787
|
+
self.logger.info(f"Table '{full_table_name}' dropped successfully.")
|
|
788
|
+
|
|
789
|
+
    async def prepare_embedding_table(
        self,
        table: str,
        schema: str = 'public',
        conn: AsyncEngine = None,
        id_column: str = 'id',
        embedding_column: str = 'embedding',
        document_column: str = 'document',
        metadata_column: str = 'cmetadata',
        dimension: int = 768,
        colbert_dimension: int = 128,  # ColBERT token dimension
        use_jsonb: bool = True,
        drop_columns: bool = False,
        create_all_indexes: bool = True,
        **kwargs
    ):
        """
        Prepare an existing PostgreSQL table as an embedding table with advanced features.

        This method prepares a table with the following columns:
        - id: unique identifier (String)
        - embedding: the vector column (Vector(dimension) or JSONB)
        - document: text column containing the document
        - collection_id: UUID column for collection identification
        - metadata: JSONB column for metadata
        - Enhanced indexing strategies for efficient querying
        - Support for multiple distance strategies (COSINE, L2, IP, etc.)

        Args:
        - table (str): Name of the table to prepare.
        - embedding_column (str): Name of the column for storing embeddings.
        - document_column (str): Name of the column for storing document text.
        - metadata_column (str): Name of the column for storing metadata.
        - dimension (int): Dimension of the embedding vector.
        - id_column (str): Name of the column for storing unique identifiers.
        - use_jsonb (bool): Whether to use JSONB for metadata storage.
        - drop_columns (bool): Whether to drop existing columns.
        - create_all_indexes (bool): Whether to create indexes for all distance strategies.
        """
        tablename = f"{schema}.{table}"
        # Drop existing columns if requested
        if drop_columns:
            columns_to_drop = [
                document_column, embedding_column, metadata_column,
                'token_embeddings', 'num_tokens'
            ]
            for column in columns_to_drop:
                await conn.execute(
                    sqlalchemy.text(
                        f'ALTER TABLE {tablename} DROP COLUMN IF EXISTS {column};'
                    )
                )
        # Create metadata column as a jsonb field
        if use_jsonb:
            await conn.execute(
                sqlalchemy.text(
                    f'ALTER TABLE {tablename} ADD COLUMN IF NOT EXISTS {metadata_column} JSONB;'
                )
            )
        # Use pgvector type for dense embeddings
        await conn.execute(
            sqlalchemy.text(
                f'ALTER TABLE {tablename} ADD COLUMN IF NOT EXISTS {embedding_column} vector({dimension});'
            )
        )
        # Add ColBERT columns for token-level embeddings
        await conn.execute(
            sqlalchemy.text(
                f'ALTER TABLE {tablename} ADD COLUMN IF NOT EXISTS token_embeddings vector({colbert_dimension})[];'
            )
        )
        await conn.execute(
            sqlalchemy.text(
                f'ALTER TABLE {tablename} ADD COLUMN IF NOT EXISTS num_tokens INTEGER;'
            )
        )
        # Create the additional columns
        for col_name, col_type in [
            (document_column, 'TEXT'),
            (id_column, 'varchar'),
        ]:
            await conn.execute(
                sqlalchemy.text(
                    f'ALTER TABLE {tablename} ADD COLUMN IF NOT EXISTS {col_name} {col_type};'
                )
            )
        # Create the Collection Column:
        await conn.execute(
            sqlalchemy.text(
                f"ALTER TABLE {tablename} ADD COLUMN IF NOT EXISTS collection_id UUID;"
            )
        )
        await conn.execute(
            sqlalchemy.text(
                f"ALTER TABLE {tablename} ADD COLUMN IF NOT EXISTS collection_id UUID DEFAULT uuid_generate_v4();"
            )
        )
        # Set the value on null values before declaring not null:
        await conn.execute(
            sqlalchemy.text(
                f"UPDATE {tablename} SET collection_id = uuid_generate_v4() WHERE collection_id IS NULL;"
            )
        )
        await conn.execute(
            sqlalchemy.text(
                f"ALTER TABLE {tablename} ALTER COLUMN collection_id SET NOT NULL;"
            )
        )
        await conn.execute(
            sqlalchemy.text(
                f"CREATE UNIQUE INDEX IF NOT EXISTS idx_{table}_{schema}_collection_id ON {tablename} (collection_id);"
            )
        )
        # ✅ CREATE COMPREHENSIVE INDEXES
        if create_all_indexes:
            await self._create_all_indexes(conn, tablename, embedding_column)
        else:
            # Create index only for current strategy
            distance_strategy_ops = {
                DistanceStrategy.COSINE: "vector_cosine_ops",
                DistanceStrategy.EUCLIDEAN_DISTANCE: "vector_l2_ops",
                DistanceStrategy.MAX_INNER_PRODUCT: "vector_ip_ops",
                DistanceStrategy.DOT_PRODUCT: "vector_ip_ops"
            }

            ops = distance_strategy_ops.get(self.distance_strategy, "vector_cosine_ops")
            strategy_name = str(self.distance_strategy).rsplit('.', maxsplit=1)[-1].lower()

            await conn.execute(
                sqlalchemy.text(
                    f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_{strategy_name} "
                    f"ON {tablename} USING ivfflat ({embedding_column} {ops});"
                )
            )
            print(f"✅ Created {strategy_name.upper()} index")

        # Create ColBERT-specific indexes
        await self._create_colbert_indexes(conn, tablename)

        # Create JSONB indexes for better performance
        await self._create_jsonb_indexes(
            conn,
            tablename,
            metadata_column,
            id_column
        )
        # Ensure the table is ready for embedding operations
        self.embedding_store = self._define_collection_store(
            table=table,
            schema=schema,
            dimension=dimension,
            id_column=id_column,
            embedding_column=embedding_column,
            document_column=self._document_column
        )
        return True

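    # Usage sketch (hypothetical table name; assumes an already-connected store and an
    # open AsyncEngine transaction). prepare_embedding_table() only ALTERs an existing
    # table, so the base table must exist beforehand:
    #
    #   >>> async with store._connection.begin() as conn:
    #   ...     await store.prepare_embedding_table(
    #   ...         table="products", schema="public", conn=conn,
    #   ...         dimension=768, use_jsonb=True, create_all_indexes=True,
    #   ...     )
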
    async def _create_all_indexes(self, conn, tablename: str, embedding_column: str):
        """Create all standard vector indexes."""
        print("🔧 Creating indexes for all distance strategies...")

        # COSINE index (most common for text embeddings)
        await conn.execute(
            sqlalchemy.text(
                f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_cosine "
                f"ON {tablename} USING ivfflat ({embedding_column} vector_cosine_ops);"
            )
        )
        print("✅ Created COSINE index")

        # L2/Euclidean index
        await conn.execute(
            sqlalchemy.text(
                f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_l2 "
                f"ON {tablename} USING ivfflat ({embedding_column} vector_l2_ops);"
            )
        )
        print("✅ Created L2 index")

        # Inner Product index
        try:
            await conn.execute(
                sqlalchemy.text(
                    f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_ip "
                    f"ON {tablename} USING ivfflat ({embedding_column} vector_ip_ops);"
                )
            )
            print("✅ Created Inner Product index")
        except Exception as e:
            print(f"⚠️ Inner Product index creation failed: {e}")

        # HNSW indexes for better performance (requires more memory)
        try:
            await conn.execute(
                sqlalchemy.text(
                    f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_hnsw_cosine "
                    f"ON {tablename} USING hnsw ({embedding_column} vector_cosine_ops);"
                )
            )
            print("✅ Created HNSW COSINE index")

            await conn.execute(
                sqlalchemy.text(
                    f"""CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_hnsw_l2
                    ON {tablename} USING hnsw ({embedding_column} vector_l2_ops) WITH (
                        m = 16,  -- graph connectivity (higher → better recall, more memory)
                        ef_construction = 200  -- controls indexing time vs. recall
                    );"""
                )
            )
            print("✅ Created HNSW EUCLIDEAN index")
        except Exception as e:
            print(f"⚠️ HNSW index creation failed (this is optional): {e}")

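    # Tuning note (not part of the method above, and an assumption about pgvector rather
    # than this package): ivfflat indexes are probed at query time via the
    # `ivfflat.probes` setting and hnsw via `hnsw.ef_search`; a minimal sketch of manual
    # tuning against the indexes created above:
    #
    #   >>> await conn.execute(sqlalchemy.text("SET ivfflat.probes = 10;"))
    #   >>> await conn.execute(sqlalchemy.text("SET hnsw.ef_search = 100;"))
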
    async def _create_colbert_indexes(self, conn, tablename: str):
        """Create ColBERT-specific indexes for token embeddings."""
        print("🔧 Creating ColBERT indexes...")

        try:
            # GIN index for array operations on token embeddings
            await conn.execute(
                sqlalchemy.text(
                    f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_token_embeddings_gin "
                    f"ON {tablename} USING gin(token_embeddings);"
                )
            )
            print("✅ Created GIN index for token embeddings")

            # Index on num_tokens for filtering
            await conn.execute(
                sqlalchemy.text(
                    f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_num_tokens "
                    f"ON {tablename} (num_tokens);"
                )
            )
            print("✅ Created index for num_tokens")

            # Partial index for non-null token embeddings
            await conn.execute(
                sqlalchemy.text(
                    f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_has_tokens "
                    f"ON {tablename} (id) WHERE token_embeddings IS NOT NULL;"
                )
            )
            print("✅ Created partial index for documents with token embeddings")

        except Exception as e:
            print(f"⚠️ ColBERT index creation failed: {e}")

    async def create_embedding_table(
        self,
        table: str,
        columns: List[str],
        schema: str = 'public',
        embedding_column: str = 'embedding',
        document_column: str = 'document',
        metadata_column: str = 'cmetadata',
        dimension: int = None,
        id_column: str = 'id',
        use_jsonb: bool = False,
        drop_columns: bool = True,
        create_all_indexes: bool = True,
        **kwargs
    ):
        """
        Create an embedding table in PostgreSQL with advanced features.

        This method creates a table with the following columns:
        - id: unique identifier (String)
        - embedding: the vector column (Vector(dimension) or JSONB)
        - document: text column containing the document
        - cmetadata: JSONB column for metadata
        - Additional columns based on the provided `columns` list
        - Enhanced indexing strategies for efficient querying
        - Support for multiple distance strategies (COSINE, L2, IP, etc.)

        Args:
        - table (str): Name of the table to create.
        - columns (List[str]): List of column names to include in the table.
        - schema (str): Database schema where the table will be created.
        - embedding_column (str): Name of the column for storing embeddings.
        - document_column (str): Name of the column for storing document text.
        - metadata_column (str): Name of the column for storing metadata.
        - dimension (int): Dimension of the embedding vector.
        - id_column (str): Name of the column for storing unique identifiers.
        - use_jsonb (bool): Whether to use JSONB for metadata storage.
        - drop_columns (bool): Whether to drop existing columns.
        - create_all_indexes (bool): Whether to create indexes for all distance strategies.

        Enhanced embedding table creation with JSONB strategy for better semantic search.

        This approach creates multiple document representations:
        1. Primary search content (emphasizing store ID)
        2. Location-based content
        3. Structured metadata for filtering
        4. Multiple embedding variations
        """
        tablename = f'{schema}.{table}'
        cols = ', '.join(columns)
        _qry = f'SELECT {cols} FROM {tablename};'
        dimension = dimension or self.dimension

        # Generate a sample embedding to determine its dimension
        sample_vector = self._embed_.embedding.embed_query("sample text")
        vector_dim = len(sample_vector)
        self.logger.notice(
            f"USING EMBED {self._embed_} with dimension {vector_dim}"
        )

        if vector_dim != dimension:
            raise ValueError(
                f"Expected embedding dimension {dimension}, but got {vector_dim}"
            )

        async with self._connection.begin() as conn:
            result = await conn.execute(sqlalchemy.text(_qry))
            rows = result.fetchall()

            await self.prepare_embedding_table(
                table=table,
                schema=schema,
                conn=conn,
                embedding_column=embedding_column,
                document_column=document_column,
                metadata_column=metadata_column,
                dimension=dimension,
                id_column=id_column,
                use_jsonb=use_jsonb,
                drop_columns=drop_columns,
                create_all_indexes=create_all_indexes,
                **kwargs
            )

            # Populate the embedding data
            for i, row in enumerate(rows):
                _id = getattr(row, id_column)
                metadata = {col: getattr(row, col) for col in columns}
                data = await self._create_metadata_structure(metadata, id_column, _id)

                # Generate embedding
                searchable_text = data['structured_search']
                print(f"🔍 Row {i + 1}/{len(rows)} - {_id}")
                print(f"   Text: {searchable_text[:100]}...")

                vector = self._embed_.embedding.embed_query(searchable_text)
                vector_str = "[" + ",".join(str(v) for v in vector) + "]"

                await conn.execute(
                    sqlalchemy.text(f"""
                        UPDATE {tablename}
                        SET {embedding_column} = :embeddings,
                            {document_column} = :document,
                            {metadata_column} = :metadata
                        WHERE {id_column} = :id
                    """),
                    {
                        "embeddings": vector_str,
                        "document": searchable_text,
                        "metadata": json_encoder(data),
                        "id": _id
                    }
                )

        print("✅ Updated Table embeddings with comprehensive indexes.")

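    # Usage sketch (hypothetical column names; assumes the source table already holds
    # the listed columns and that self._embed_ produces vectors of the given dimension):
    #
    #   >>> await store.create_embedding_table(
    #   ...     table="stores", schema="public",
    #   ...     columns=["store_id", "store_name", "city"],
    #   ...     id_column="store_id", dimension=768,
    #   ... )
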
    def _create_natural_searchable_text(
        self,
        metadata: dict,
        id_column: str,
        record_id: str
    ) -> str:
        """
        Create well-structured, natural language text with proper separation.

        This creates clean, readable text that embedding models can understand better.
        """
        # Start with the ID in multiple formats for exact matching
        text_parts = [
            f"ID: {record_id}",
            f"Identifier: {record_id}",
            id_column + ": " + record_id
        ]

        # Process each field to create natural language descriptions
        for key, value in metadata.items():
            if value is None or value == '':
                continue
            clean_value = value.strip() if isinstance(value, str) else str(value)
            text_parts.append(f"{key}: {clean_value}")
            # Add the field in natural language format
            clean_key = key.replace('_', ' ').title()
            text_parts.append(f"{clean_key}={clean_value}")

        # Join with separators and clean up
        searchable_text = ', '.join(text_parts) + '.'

        return searchable_text

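    # Worked example (hypothetical data) of the text built above: for
    # metadata={"store_name": "Main St"}, id_column="store_id", record_id="S-001",
    # the joined result is roughly:
    #   "ID: S-001, Identifier: S-001, store_id: S-001, store_name: Main St, Store Name=Main St."
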
    def _create_structured_search_text(self, metadata: dict, id_column: str, record_id: str) -> str:
        """
        Create a more structured but still readable search text.

        This emphasizes key-value relationships while staying readable.
        """
        # ID section with emphasis
        kv_sections = [
            f"ID: {record_id}",
            f"Identifier: {record_id}",
            id_column + ": " + record_id
        ]

        # Key-value sections with clean separation
        for key, value in metadata.items():
            if value is None or value == '':
                continue

            # Clean key-value representation
            clean_key = key.replace('_', ' ').title()
            kv_sections.append(f"{clean_key}: {value}")
            kv_sections.append(f"{key}: {value}")

        # Combine with proper separation
        return ' | '.join(kv_sections)

    async def _create_metadata_structure(
        self,
        metadata: dict,
        id_column: str,
        _id: str
    ):
        """Create a structured metadata representation for the document."""
        # Create a structured metadata representation
        enhanced_metadata = {
            "id": _id,
            id_column: _id,
            "_variants": [
                _id,
                _id.lower(),
                _id.upper()
            ]
        }
        for key, value in metadata.items():
            enhanced_metadata[key] = value
            # Create searchable variants for key fields
            if value and isinstance(value, str):
                variants = [value, value.lower(), value.upper()]
                # Add variants without special characters
                clean_value = ''.join(c for c in str(value) if c.isalnum() or c.isspace())
                if clean_value != value:
                    variants.append(clean_value)
                enhanced_metadata[f"_{key}_variants"] = list(set(variants))
        # create a full-text search field of searchable content
        enhanced_metadata['searchable_content'] = self._create_natural_searchable_text(
            metadata, id_column, _id
        )

        # Also create a structured search text that emphasizes important fields
        enhanced_metadata['structured_search'] = self._create_structured_search_text(
            metadata, id_column, _id
        )

        return enhanced_metadata

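    # Worked example (hypothetical data): for metadata={"city": "Boston"},
    # id_column="store_id", _id="S-001", the enhanced metadata carries "id" and
    # "store_id" set to "S-001", "_variants" with the id in original/lower/upper case,
    # "_city_variants" with case variants of "Boston", plus the 'searchable_content'
    # and 'structured_search' strings produced by the two helpers above.
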
    async def _create_jsonb_indexes(
        self,
        conn,
        tablename: str,
        metadata_col: str,
        id_column: str
    ):
        """Create optimized JSONB indexes for better search performance."""

        print("🔧 Creating JSONB indexes on Metadata for optimized search...")

        # Index for ID searches
        await conn.execute(
            sqlalchemy.text(f"""
                CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_{id_column}
                ON {tablename} USING BTREE (({metadata_col}->>'{id_column}'));
            """)
        )
        await conn.execute(
            sqlalchemy.text(f"""
                CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_id
                ON {tablename} USING BTREE (({metadata_col}->>'id'));
            """)
        )

        # GIN index for full-text search on searchable content
        await conn.execute(
            sqlalchemy.text(f"""
                CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_fulltext
                ON {tablename} USING GIN (to_tsvector('english', {metadata_col}->>'searchable_content'));
            """)
        )

        # GIN index for JSONB structure searches
        await conn.execute(
            sqlalchemy.text(f"""
                CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_metadata_gin
                ON {tablename} USING GIN ({metadata_col});
            """)
        )
        print("✅ Created optimized JSONB indexes")

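    # Query sketch (hypothetical table/column names) of what the indexes above are meant
    # to accelerate: exact-ID lookups use the BTREE expression indexes, free text uses
    # the tsvector GIN index, and JSONB containment uses the metadata GIN index:
    #
    #   SELECT id FROM public.products WHERE cmetadata->>'store_id' = 'S-001';
    #   SELECT id FROM public.products
    #    WHERE to_tsvector('english', cmetadata->>'searchable_content') @@ plainto_tsquery('main street');
    #   SELECT id FROM public.products WHERE cmetadata @> '{"city": "Boston"}';
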
    async def add_colbert_document(
        self,
        document_id: str,
        content: str,
        token_embeddings: np.ndarray,
        table: str,
        schema: str = 'public',
        metadata: Optional[Dict[str, Any]] = None,
        document_column: str = 'document',
        metadata_column: str = 'cmetadata',
        id_column: str = 'id',
        **kwargs
    ) -> None:
        """
        Add a document with ColBERT token embeddings to the specified table.

        Args:
            document_id: Unique identifier for the document
            content: The document text content
            token_embeddings: NumPy array of token embeddings (shape: [num_tokens, embedding_dim])
            table: The name of the table
            schema: The database schema where the table resides
            metadata: Optional metadata dictionary
            document_column: Name of the document content column
            metadata_column: Name of the metadata column
            id_column: Name of the ID column
        """
        if not self._connected:
            await self.connection()

        # Ensure the ORM table is initialized
        if self.embedding_store is None:
            self.embedding_store = self._define_collection_store(
                table=table,
                schema=schema,
                dimension=self.dimension,
                id_column=id_column,
                document_column=document_column,
                metadata_column=metadata_column,
                text_column=self._text_column,
            )

        # Convert numpy array to list format for PostgreSQL
        if isinstance(token_embeddings, np.ndarray):
            token_embeddings_list = token_embeddings.tolist()
        else:
            token_embeddings_list = token_embeddings

        num_tokens = len(token_embeddings_list)

        # Prepare the insert/upsert data
        values = {
            id_column: document_id,
            document_column: content,
            'token_embeddings': token_embeddings_list,
            'num_tokens': num_tokens,
            metadata_column: metadata or {}
        }

        # Build insert statement with upsert capability
        insert_stmt = insert(self.embedding_store).values(values)

        # Create upsert statement (ON CONFLICT DO UPDATE)
        upsert_stmt = insert_stmt.on_conflict_do_update(
            index_elements=[id_column],
            set_={
                document_column: getattr(insert_stmt.excluded, document_column),
                'token_embeddings': insert_stmt.excluded.token_embeddings,
                'num_tokens': insert_stmt.excluded.num_tokens,
                metadata_column: getattr(insert_stmt.excluded, metadata_column),
            }
        )

        try:
            async with self._connection.begin() as conn:
                await conn.execute(upsert_stmt)

            self.logger.info(
                f"Successfully added ColBERT document '{document_id}' with {num_tokens} tokens to '{schema}.{table}'"
            )
        except Exception as e:
            self.logger.error(f"Error adding ColBERT document: {e}")
            raise

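    # Usage sketch (hypothetical encoder; add_colbert_document only stores what it is
    # given, so the token matrix must already be shaped [num_tokens, colbert_dimension]):
    #
    #   >>> token_matrix = np.random.rand(24, 128).astype(np.float32)  # stand-in for a ColBERT encoder
    #   >>> await store.add_colbert_document(
    #   ...     document_id="doc-42", content="...", token_embeddings=token_matrix,
    #   ...     table="products", schema="public", metadata={"source_type": "papers"},
    #   ... )
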
    async def colbert_search(
        self,
        query_tokens: np.ndarray,
        table: str,
        schema: str = 'public',
        top_k: int = 10,
        metadata_filters: Optional[Dict[str, Any]] = None,
        min_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        id_column: str = 'id',
        document_column: str = 'document',
        metadata_column: str = 'cmetadata',
        additional_columns: Optional[List[str]] = None
    ) -> List[SearchResult]:
        """
        Perform ColBERT search with late interaction using MaxSim scoring.

        Args:
            query_tokens: NumPy array of query token embeddings (shape: [num_query_tokens, embedding_dim])
            table: Table name
            schema: Schema name
            top_k: Number of results to return
            metadata_filters: Optional metadata filters
            min_tokens: Minimum number of tokens in documents to consider
            max_tokens: Maximum number of tokens in documents to consider
            id_column: Name of the ID column
            document_column: Name of the document content column
            metadata_column: Name of the metadata column
            additional_columns: Additional columns to include in results

        Returns:
            List of SearchResult objects ordered by ColBERT score (descending)
        """
        if not self._connected:
            await self.connection()

        # Ensure the ORM table is initialized
        if self.embedding_store is None:
            self.embedding_store = self._define_collection_store(
                table=table,
                schema=schema,
                dimension=self.dimension,
                id_column=id_column,
                document_column=document_column,
                metadata_column=metadata_column,
                text_column=self._text_column,
            )

        # Convert query tokens to list format
        if isinstance(query_tokens, np.ndarray):
            query_tokens_list = query_tokens.tolist()
        else:
            query_tokens_list = query_tokens

        # Get column objects
        id_col = getattr(self.embedding_store, id_column)
        content_col = getattr(self.embedding_store, document_column)
        metadata_col = getattr(self.embedding_store, metadata_column)
        token_embeddings_col = getattr(self.embedding_store, 'token_embeddings')
        num_tokens_col = getattr(self.embedding_store, 'num_tokens')
        collection_id_col = getattr(self.embedding_store, 'collection_id')

        # Build select columns
        select_columns = [
            id_col,
            content_col,
            metadata_col,
            collection_id_col,
            func.max_sim(token_embeddings_col, query_tokens_list).label('colbert_score')
        ]

        # Add additional columns dynamically
        if additional_columns:
            for col_name in additional_columns:
                additional_col = literal_column(f'"{col_name}"').label(col_name)
                select_columns.append(additional_col)

        # Build the query
        stmt = (
            select(*select_columns)
            .select_from(self.embedding_store)
            .where(token_embeddings_col.isnot(None))  # Only documents with token embeddings
            .order_by(func.max_sim(token_embeddings_col, query_tokens_list).desc())
            .limit(top_k)
        )

        # Apply token count filters
        if min_tokens is not None:
            stmt = stmt.where(num_tokens_col >= min_tokens)
        if max_tokens is not None:
            stmt = stmt.where(num_tokens_col <= max_tokens)

        # Apply metadata filters
        if metadata_filters:
            for key, value in metadata_filters.items():
                stmt = stmt.where(metadata_col[key].astext == str(value))

        try:
            async with self._connection.connect() as conn:
                result = await conn.execute(stmt)
                rows = result.fetchall()

                # Create SearchResult objects
                results = []
                for row in rows:
                    # Enhance metadata with additional info
                    metadata = dict(row[2]) if row[2] else {}
                    metadata['collection_id'] = row[3]
                    metadata['colbert_score'] = float(row[4])

                    # Add additional columns to metadata
                    if additional_columns:
                        for i, col_name in enumerate(additional_columns):
                            metadata[col_name] = row[5 + i]

                    search_result = SearchResult(
                        id=row[0],
                        content=row[1],
                        metadata=metadata,
                        score=float(row[4])  # ColBERT score
                    )
                    results.append(search_result)

                self.logger.info(
                    f"ColBERT search returned {len(results)} results from {schema}.{table}"
                )
                return results

        except Exception as e:
            self.logger.error(f"Error during ColBERT search: {e}")
            raise

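    # Note on the scoring call above: func.max_sim(...) renders as a SQL call to a
    # max_sim() function, so the database is assumed to provide a MaxSim implementation
    # over the token_embeddings arrays; this method does not create it. Usage sketch
    # (hypothetical query tokens):
    #
    #   >>> results = await store.colbert_search(
    #   ...     query_tokens=np.random.rand(8, 128).astype(np.float32),
    #   ...     table="products", schema="public", top_k=5, min_tokens=4,
    #   ... )
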
    async def hybrid_search(
        self,
        query: str,
        query_tokens: Optional[np.ndarray] = None,
        table: str = None,
        schema: str = None,
        top_k: int = 10,
        dense_weight: float = 0.7,
        colbert_weight: float = 0.3,
        metadata_filters: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> List[SearchResult]:
        """
        Perform hybrid search combining dense embeddings and ColBERT token matching.

        Args:
            query: Text query
            query_tokens: Optional pre-computed query token embeddings
            table: Table name
            schema: Schema name
            top_k: Number of final results
            dense_weight: Weight for dense similarity scores (0-1)
            colbert_weight: Weight for ColBERT scores (0-1)
            metadata_filters: Metadata filters to apply

        Returns:
            List of SearchResult objects with combined scores
        """
        if not self._connected:
            await self.connection()

        table = table or self.table_name
        schema = schema or self.schema

        # Fetch more candidates for reranking
        candidate_count = min(top_k * 3, 100)

        # Get dense similarity results
        dense_results = await self.similarity_search(
            query=query,
            table=table,
            schema=schema,
            limit=candidate_count,
            metadata_filters=metadata_filters,
            **kwargs
        )

        # Get ColBERT results if token embeddings provided
        colbert_results = []
        if query_tokens is not None:
            colbert_results = await self.colbert_search(
                query_tokens=query_tokens,
                table=table,
                schema=schema,
                top_k=candidate_count,
                metadata_filters=metadata_filters
            )

        # Combine and rerank results
        combined_results = self._combine_search_results(
            dense_results=dense_results,
            colbert_results=colbert_results,
            dense_weight=dense_weight,
            colbert_weight=colbert_weight
        )

        # Return top-k results
        return combined_results[:top_k]

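    # Usage sketch: the weights are applied to min-max normalized scores in
    # _combine_search_results below, so they act as relative weights (0.7/0.3 here).
    # encode_query_tokens is a hypothetical helper standing in for a ColBERT query encoder:
    #
    #   >>> hits = await store.hybrid_search(
    #   ...     query="coffee shops near Boston",
    #   ...     query_tokens=encode_query_tokens("coffee shops near Boston"),
    #   ...     table="products", top_k=10, dense_weight=0.7, colbert_weight=0.3,
    #   ... )
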
    def _combine_search_results(
        self,
        dense_results: List[SearchResult],
        colbert_results: List[SearchResult],
        dense_weight: float,
        colbert_weight: float
    ) -> List[SearchResult]:
        """
        Combine and rerank results from dense and ColBERT searches.
        """
        # Create lookup dictionaries
        dense_lookup = {result.id: result for result in dense_results}
        colbert_lookup = {result.id: result for result in colbert_results}

        # Get all unique document IDs
        all_ids = set(dense_lookup.keys()) | set(colbert_lookup.keys())

        # Normalize scores to 0-1 range
        if dense_results:
            dense_scores = [r.score for r in dense_results]
            dense_min, dense_max = min(dense_scores), max(dense_scores)
            dense_range = dense_max - dense_min if dense_max != dense_min else 1
        else:
            dense_min, dense_range = 0, 1

        if colbert_results:
            colbert_scores = [r.score for r in colbert_results]
            colbert_min, colbert_max = min(colbert_scores), max(colbert_scores)
            colbert_range = colbert_max - colbert_min if colbert_max != colbert_min else 1
        else:
            colbert_min, colbert_range = 0, 1

        # Combine results
        combined_results = []
        for doc_id in all_ids:
            dense_result = dense_lookup.get(doc_id)
            colbert_result = colbert_lookup.get(doc_id)

            # Normalize scores
            dense_score_norm = 0
            if dense_result:
                dense_score_norm = (dense_result.score - dense_min) / dense_range

            colbert_score_norm = 0
            if colbert_result:
                colbert_score_norm = (colbert_result.score - colbert_min) / colbert_range

            # Calculate combined score
            combined_score = (
                dense_weight * dense_score_norm +
                colbert_weight * colbert_score_norm
            )

            # Use the result with more complete information
            primary_result = dense_result or colbert_result

            # Create combined result
            combined_result = SearchResult(
                id=primary_result.id,
                content=primary_result.content,
                metadata={
                    **primary_result.metadata,
                    'dense_score': dense_result.score if dense_result else 0,
                    'colbert_score': colbert_result.score if colbert_result else 0,
                    'combined_score': combined_score
                },
                score=combined_score
            )
            combined_results.append(combined_result)

        # Sort by combined score (descending)
        combined_results.sort(key=lambda x: x.score, reverse=True)

        return combined_results

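    # Worked example of the combination above (hypothetical scores): with dense scores
    # {A: 0.9, B: 0.5} and ColBERT scores {A: 12.0, B: 20.0}, min-max normalization gives
    # A = (1.0, 0.0) and B = (0.0, 1.0); with dense_weight=0.7 and colbert_weight=0.3 the
    # combined scores are A = 0.7*1.0 + 0.3*0.0 = 0.7 and B = 0.7*0.0 + 0.3*1.0 = 0.3,
    # so A ranks first.
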
    async def mmr_search(
        self,
        query: str,
        table: str = None,
        schema: str = None,
        k: int = 10,
        fetch_k: int = None,
        lambda_mult: float = 0.5,
        metadata_filters: Optional[Dict[str, Any]] = None,
        score_threshold: Optional[float] = None,
        metric: str = None,
        embedding_column: str = 'embedding',
        content_column: str = 'document',
        metadata_column: str = 'cmetadata',
        id_column: str = 'id',
        additional_columns: Optional[List[str]] = None
    ) -> List[SearchResult]:
        """
        Perform Maximal Marginal Relevance (MMR) search to balance relevance and diversity.

        MMR helps avoid redundant results by selecting documents that are relevant to the query
        but diverse from each other.

        Args:
            query: The search query text
            table: Table name (optional, uses default if not provided)
            schema: Schema name (optional, uses default if not provided)
            k: Number of final results to return
            fetch_k: Number of candidate documents to fetch (default: k * 3)
            lambda_mult: MMR diversity parameter (0-1):
                - 1.0 = pure relevance (no diversity)
                - 0.0 = pure diversity (no relevance)
                - 0.5 = balanced (default)
            metadata_filters: Dictionary of metadata filters to apply
            score_threshold: Maximum distance threshold for initial candidates
            metric: Distance metric to use ('COSINE', 'L2', 'IP')
            embedding_column: Name of the embedding column
            content_column: Name of the content column
            metadata_column: Name of the metadata column
            id_column: Name of the ID column
            additional_columns: Additional columns to include in results

        Returns:
            List of SearchResult objects selected via MMR algorithm
        """
        if not self._connected:
            await self.connection()

        # Default to fetching 3x more candidates than final results
        if fetch_k is None:
            fetch_k = max(k * 3, 20)

        # Step 1: Get initial candidates using similarity search
        candidates = await self.similarity_search(
            query=query,
            table=table,
            schema=schema,
            limit=fetch_k,
            metadata_filters=metadata_filters,
            score_threshold=score_threshold,
            metric=metric,
            embedding_column=embedding_column,
            content_column=content_column,
            metadata_column=metadata_column,
            id_column=id_column,
            additional_columns=additional_columns
        )

        if len(candidates) <= k:
            # If we have fewer candidates than requested results, return all
            return candidates

        # Step 2: Get embeddings for MMR computation
        # We need to fetch the actual embedding vectors for similarity computation
        candidate_embeddings = await self._fetch_embeddings_for_mmr(
            candidate_ids=[result.id for result in candidates],
            table=table,
            schema=schema,
            embedding_column=embedding_column,
            id_column=id_column
        )

        # Step 3: Get query embedding
        query_embedding = self._embed_.embed_query(query)

        # Step 4: Run MMR algorithm
        selected_results = self._mmr_algorithm(
            query_embedding=query_embedding,
            candidates=candidates,
            candidate_embeddings=candidate_embeddings,
            k=k,
            lambda_mult=lambda_mult,
            metric=metric or self.distance_strategy
        )

        self.logger.info(
            f"MMR search selected {len(selected_results)} results from {len(candidates)} candidates "
            f"(λ={lambda_mult})"
        )

        return selected_results

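    # Usage sketch: lambda_mult trades relevance against diversity (1.0 = pure relevance,
    # 0.0 = pure diversity) and fetch_k defaults to max(k * 3, 20) candidates. The filter
    # value shown is hypothetical:
    #
    #   >>> diverse = await store.mmr_search(
    #   ...     query="running shoes", table="products", k=5, lambda_mult=0.5,
    #   ...     metadata_filters={"category": "footwear"},
    #   ... )
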
    async def _fetch_embeddings_for_mmr(
        self,
        candidate_ids: List[str],
        table: str,
        schema: str,
        embedding_column: str,
        id_column: str
    ) -> Dict[str, np.ndarray]:
        """
        Fetch embedding vectors for candidate documents.

        Args:
            candidate_ids: List of document IDs to fetch embeddings for
            table: Table name
            schema: Schema name
            embedding_column: Name of the embedding column
            id_column: Name of the ID column

        Returns:
            Dictionary mapping document ID to embedding vector
        """
        if not self.embedding_store:
            self.embedding_store = self._define_collection_store(
                table=table,
                schema=schema,
                dimension=self.dimension,
                id_column=self._id_column,
                embedding_column=embedding_column,
                document_column=self._document_column,
                metadata_column='cmetadata',
                text_column=self._text_column,
            )

        # Get column objects
        id_col = getattr(self.embedding_store, id_column)
        embedding_col = getattr(self.embedding_store, embedding_column)

        # Build query to fetch embeddings
        stmt = (
            select(id_col, embedding_col)
            .select_from(self.embedding_store)
            .where(id_col.in_(candidate_ids))
        )

        embeddings_dict = {}
        async with self.session() as session:
            result = await session.execute(stmt)
            rows = result.fetchall()

            for row in rows:
                doc_id = row[0]
                embedding = row[1]

                # Convert to numpy array if needed
                if isinstance(embedding, (list, tuple)):
                    embeddings_dict[doc_id] = np.array(embedding, dtype=np.float32)
                elif hasattr(embedding, '__array__'):
                    embeddings_dict[doc_id] = np.array(embedding, dtype=np.float32)
                else:
                    # Handle pgvector Vector type
                    embeddings_dict[doc_id] = np.array(embedding, dtype=np.float32)

        return embeddings_dict

    def _mmr_algorithm(
        self,
        query_embedding: np.ndarray,
        candidates: List[SearchResult],
        candidate_embeddings: Dict[str, np.ndarray],
        k: int,
        lambda_mult: float,
        metric: str
    ) -> List[SearchResult]:
        """
        Core MMR algorithm implementation.

        Args:
            query_embedding: Query embedding vector
            candidates: List of candidate SearchResult objects
            candidate_embeddings: Dictionary mapping doc ID to embedding vector
            k: Number of results to select
            lambda_mult: MMR diversity parameter (0-1)
            metric: Distance metric to use

        Returns:
            List of selected SearchResult objects ordered by MMR score
        """
        if len(candidates) <= k:
            return candidates

        # Convert query embedding to numpy array
        if not isinstance(query_embedding, np.ndarray):
            query_embedding = np.array(query_embedding, dtype=np.float32)

        # Prepare data structures
        selected_indices = []
        remaining_indices = list(range(len(candidates)))

        # Step 1: Select the most relevant document first
        query_similarities = []
        for candidate in candidates:
            doc_embedding = candidate_embeddings.get(candidate.id)
            if doc_embedding is not None:
                similarity = self._compute_similarity(query_embedding, doc_embedding, metric)
                query_similarities.append(similarity)
            else:
                # Fallback to distance score if embedding not available
                # Convert distance to similarity (lower distance = higher similarity)
                query_similarities.append(1.0 / (1.0 + candidate.score))

        # Select the most similar document first
        best_idx = np.argmax(query_similarities)
        selected_indices.append(best_idx)
        remaining_indices.remove(best_idx)

        # Step 2: Iteratively select remaining documents using MMR
        for _ in range(min(k - 1, len(remaining_indices))):
            mmr_scores = []

            for idx in remaining_indices:
                candidate = candidates[idx]
                doc_embedding = candidate_embeddings.get(candidate.id)

                if doc_embedding is None:
                    # Fallback scoring if embedding not available
                    mmr_score = lambda_mult * query_similarities[idx]
                    mmr_scores.append(mmr_score)
                    continue

                # Relevance: similarity to query
                relevance = query_similarities[idx]

                # Diversity: maximum similarity to already selected documents
                max_similarity_to_selected = 0.0
                for selected_idx in selected_indices:
                    selected_candidate = candidates[selected_idx]
                    selected_embedding = candidate_embeddings.get(selected_candidate.id)

                    if selected_embedding is not None:
                        similarity = self._compute_similarity(doc_embedding, selected_embedding, metric)
                        max_similarity_to_selected = max(max_similarity_to_selected, similarity)

                # MMR formula: λ * relevance - (1-λ) * max_similarity_to_selected
                mmr_score = (
                    lambda_mult * relevance -
                    (1.0 - lambda_mult) * max_similarity_to_selected
                )
                mmr_scores.append(mmr_score)

            # Select document with highest MMR score
            if mmr_scores:
                best_remaining_idx = np.argmax(mmr_scores)
                best_idx = remaining_indices[best_remaining_idx]
                selected_indices.append(best_idx)
                remaining_indices.remove(best_idx)

        # Step 3: Return selected results with MMR scores in metadata
        selected_results = []
        for i, idx in enumerate(selected_indices):
            result = candidates[idx]
            # Add MMR ranking to metadata
            enhanced_metadata = dict(result.metadata)
            enhanced_metadata['mmr_rank'] = i + 1
            enhanced_metadata['mmr_lambda'] = lambda_mult
            enhanced_metadata['original_rank'] = idx + 1

            enhanced_result = SearchResult(
                id=result.id,
                content=result.content,
                metadata=enhanced_metadata,
                score=result.score
            )
            selected_results.append(enhanced_result)

        return selected_results

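    # Worked example of the MMR update above (hypothetical similarities): with
    # lambda_mult=0.5, a candidate whose query relevance is 0.8 but whose maximum
    # similarity to the already-selected set is 0.9 scores 0.5*0.8 - 0.5*0.9 = -0.05,
    # while a less relevant (0.6) but novel (0.1) candidate scores 0.5*0.6 - 0.5*0.1 = 0.25
    # and is selected next.
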
    def _compute_similarity(
        self,
        embedding1: np.ndarray,
        embedding2: np.ndarray,
        metric: Union[str, Any]
    ) -> float:
        """
        Compute similarity between two embeddings based on the specified metric.

        Args:
            embedding1: First embedding vector (numpy array or list)
            embedding2: Second embedding vector (numpy array or list)
            metric: Distance metric ('COSINE', 'L2', 'IP', etc.)

        Returns:
            Similarity score (higher = more similar)
        """
        # Convert to numpy arrays if needed
        if isinstance(embedding1, list):
            embedding1 = np.array(embedding1, dtype=np.float32)
        if isinstance(embedding2, list):
            embedding2 = np.array(embedding2, dtype=np.float32)

        # Ensure embeddings are numpy arrays
        if not isinstance(embedding1, np.ndarray):
            embedding1 = np.array(embedding1, dtype=np.float32)
        if not isinstance(embedding2, np.ndarray):
            embedding2 = np.array(embedding2, dtype=np.float32)

        # Ensure embeddings are 2D arrays for sklearn
        emb1 = embedding1.reshape(1, -1)
        emb2 = embedding2.reshape(1, -1)

        # Convert string metrics to DistanceStrategy enum if needed
        if isinstance(metric, str):
            metric_mapping = {
                'COSINE': DistanceStrategy.COSINE,
                'L2': DistanceStrategy.EUCLIDEAN_DISTANCE,
                'EUCLIDEAN': DistanceStrategy.EUCLIDEAN_DISTANCE,
                'IP': DistanceStrategy.MAX_INNER_PRODUCT,
                'DOT': DistanceStrategy.DOT_PRODUCT,
                'DOT_PRODUCT': DistanceStrategy.DOT_PRODUCT,
                'MAX_INNER_PRODUCT': DistanceStrategy.MAX_INNER_PRODUCT
            }
            strategy = metric_mapping.get(metric.upper(), DistanceStrategy.COSINE)
        else:
            strategy = metric

        if strategy == DistanceStrategy.COSINE:
            # Cosine similarity (returns similarity, not distance)
            similarity = cosine_similarity(emb1, emb2)[0, 0]
            return float(similarity)

        elif strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
            # Convert Euclidean distance to similarity
            distance = euclidean_distances(emb1, emb2)[0, 0]
            similarity = 1.0 / (1.0 + distance)
            return float(similarity)

        elif strategy in [DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.DOT_PRODUCT]:
            # Dot product (inner product)
            similarity = np.dot(embedding1.flatten(), embedding2.flatten())
            return float(similarity)

        else:
            # Default to cosine similarity
            similarity = cosine_similarity(emb1, emb2)[0, 0]
            return float(similarity)

    async def delete_documents(
        self,
        documents: Optional[List[Document]] = None,
        pk: str = 'source_type',
        values: Optional[Union[str, List[str]]] = None,
        table: Optional[str] = None,
        schema: Optional[str] = None,
        metadata_column: Optional[str] = None,
        **kwargs
    ) -> int:
        """
        Delete documents from the vector store based on metadata field values.

        Args:
            documents: List of documents whose metadata values will be used for deletion.
                If provided, the pk field values will be extracted from these documents.
            pk: The metadata field name to use for deletion (default: 'source_type')
            values: Specific values to delete. Can be a single string or list of strings.
                If provided, this takes precedence over extracting from documents.
            table: Override table name
            schema: Override schema name
            metadata_column: Override metadata column name

        Returns:
            int: Number of documents deleted

        Examples:
            # Delete all documents with source_type 'papers'
            deleted_count = await store.delete_documents(values='papers')

            # Delete documents with multiple source types
            deleted_count = await store.delete_documents(values=['papers', 'reports'])

            # Delete based on documents' metadata
            docs_to_delete = [Document(page_content="test", metadata={"source_type": "papers"})]
            deleted_count = await store.delete_documents(documents=docs_to_delete)

            # Delete by different metadata field
            deleted_count = await store.delete_documents(pk='category', values='obsolete')
        """
        if not self._connected:
            await self.connection()

        # Use defaults from instance if not provided
        table = table or self.table_name
        schema = schema or self.schema
        metadata_column = metadata_column or self._document_column or 'cmetadata'

        # Extract values to delete
        delete_values = []

        if values is not None:
            # Use provided values
            if isinstance(values, str):
                delete_values = [values]
            else:
                delete_values = list(values)
        elif documents:
            # Extract values from documents metadata
            for doc in documents:
                if hasattr(doc, 'metadata') and doc.metadata and pk in doc.metadata:
                    value = doc.metadata[pk]
                    if value and value not in delete_values:
                        delete_values.append(value)
        else:
            raise ValueError("Either 'documents' or 'values' parameter must be provided")

        if not delete_values:
            self.logger.warning(f"No values found for field '{pk}' to delete")
            return 0

        # Construct full table name
        full_table_name = f"{schema}.{table}" if schema != 'public' else table

        deleted_count = 0

        try:
            async with self.session() as session:
                for value in delete_values:
                    # Use JSONB operator to find matching metadata
                    delete_query = text(f"""
                        DELETE FROM {full_table_name}
                        WHERE {metadata_column}->>:pk = :value
                    """)

                    result = await session.execute(
                        delete_query,
                        {"pk": pk, "value": str(value)}
                    )

                    rows_deleted = result.rowcount
                    deleted_count += rows_deleted

                    self.logger.info(
                        f"Deleted {rows_deleted} documents with {pk}='{value}' from {full_table_name}"
                    )

            self.logger.info(f"Total deleted: {deleted_count} documents")
            return deleted_count

        except Exception as e:
            self.logger.error(f"Error deleting documents: {e}")
            raise RuntimeError(f"Failed to delete documents: {e}") from e

    async def delete_documents_by_filter(
        self,
        filter_dict: Dict[str, Union[str, List[str]]],
        table: Optional[str] = None,
        schema: Optional[str] = None,
        metadata_column: Optional[str] = None,
        **kwargs
    ) -> int:
        """
        Delete documents based on multiple metadata field conditions.

        Args:
            filter_dict: Dictionary of field_name: value(s) pairs for deletion criteria
            table: Override table name
            schema: Override schema name
            metadata_column: Override metadata column name

        Returns:
            int: Number of documents deleted

        Example:
            # Delete documents with source_type='papers' AND category='research'
            deleted = await store.delete_documents_by_filter({
                'source_type': 'papers',
                'category': 'research'
            })

            # Delete documents with source_type in ['papers', 'reports']
            deleted = await store.delete_documents_by_filter({
                'source_type': ['papers', 'reports']
            })
        """
        if not self._connected:
            await self.connection()

        if not filter_dict:
            raise ValueError("filter_dict cannot be empty")

        # Use defaults from instance if not provided
        table = table or self.table_name
        schema = schema or self.schema
        metadata_column = metadata_column or self._document_column or 'cmetadata'

        # Construct full table name
        full_table_name = f"{schema}.{table}" if schema != 'public' else table

        # Build WHERE conditions
        where_conditions = []
        params = {}

        for field, values in filter_dict.items():
            if isinstance(values, (list, tuple)):
                # Handle multiple values with IN operator
                placeholders = []
                for i, value in enumerate(values):
                    param_name = f"{field}_{i}"
                    placeholders.append(f":{param_name}")
                    params[param_name] = str(value)

                condition = f"{metadata_column}->>'{field}' IN ({', '.join(placeholders)})"
                where_conditions.append(condition)
            else:
                # Handle single value
                param_name = f"{field}_single"
                where_conditions.append(f"{metadata_column}->>'{field}' = :{param_name}")
                params[param_name] = str(values)

        # Combine conditions with AND
        where_clause = " AND ".join(where_conditions)

        delete_query = text(f"""
            DELETE FROM {full_table_name}
            WHERE {where_clause}
        """)

        try:
            async with self.session() as session:
                result = await session.execute(delete_query, params)
                deleted_count = result.rowcount

                self.logger.info(
                    f"Deleted {deleted_count} documents from {full_table_name} "
                    f"with filter: {filter_dict}"
                )

                return deleted_count

        except Exception as e:
            self.logger.error(f"Error deleting documents by filter: {e}")
            raise RuntimeError(f"Failed to delete documents by filter: {e}") from e

    async def delete_all_documents(
        self,
        table: Optional[str] = None,
        schema: Optional[str] = None,
        confirm: bool = False,
        **kwargs
    ) -> int:
        """
        Delete ALL documents from the vector store table.

        WARNING: This will delete all data in the table!

        Args:
            table: Override table name
            schema: Override schema name
            confirm: Must be set to True to proceed with deletion

        Returns:
            int: Number of documents deleted
        """
        if not confirm:
            raise ValueError(
                "This operation will delete ALL documents. "
                "Set confirm=True to proceed."
            )

        if not self._connected:
            await self.connection()

        # Use defaults from instance if not provided
        table = table or self.table_name
        schema = schema or self.schema

        # Construct full table name
        full_table_name = f"{schema}.{table}" if schema != 'public' else table

        try:
            async with self.session() as session:
                # First count existing documents
                count_query = text(f"SELECT COUNT(*) FROM {full_table_name}")
                count_result = await session.execute(count_query)
                total_docs = count_result.scalar()

                if total_docs == 0:
                    self.logger.info(f"No documents to delete from {full_table_name}")
                    return 0

                # Delete all documents
                delete_query = text(f"DELETE FROM {full_table_name}")
                result = await session.execute(delete_query)
                deleted_count = result.rowcount

                self.logger.warning(
                    f"DELETED ALL {deleted_count} documents from {full_table_name}"
                )

                return deleted_count

        except Exception as e:
            self.logger.error(f"Error deleting all documents: {e}")
            raise RuntimeError(f"Failed to delete all documents: {e}") from e

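    # A minimal sketch of the confirm guard above, assuming a connected store
    # instance named `store` (illustrative name). Calling without confirm=True
    # raises ValueError, so a full wipe always has to be explicit:
    #
    #     try:
    #         await store.delete_all_documents()              # raises ValueError
    #     except ValueError:
    #         deleted = await store.delete_all_documents(confirm=True)
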
    async def delete_documents_by_ids(
        self,
        document_ids: List[str],
        table: Optional[str] = None,
        schema: Optional[str] = None,
        id_column: Optional[str] = None,
        **kwargs
    ) -> int:
        """
        Delete documents by their IDs.

        Args:
            document_ids: List of document IDs to delete
            table: Override table name
            schema: Override schema name
            id_column: Override ID column name

        Returns:
            int: Number of documents deleted

        Example:
            deleted_count = await store.delete_documents_by_ids([
                "doc_1", "doc_2", "doc_3"
            ])
        """
        if not self._connected:
            await self.connection()

        if not document_ids:
            self.logger.warning("No document IDs provided for deletion")
            return 0

        # Use defaults from instance if not provided
        table = table or self.table_name
        schema = schema or self.schema
        id_column = id_column or self._id_column

        # Construct full table name
        full_table_name = f"{schema}.{table}" if schema != 'public' else table

        # Build parameterized query for multiple IDs
        placeholders = []
        params = {}
        for i, doc_id in enumerate(document_ids):
            param_name = f"id_{i}"
            placeholders.append(f":{param_name}")
            params[param_name] = str(doc_id)

        delete_query = text(f"""
            DELETE FROM {full_table_name}
            WHERE {id_column} IN ({', '.join(placeholders)})
        """)

        try:
            async with self.session() as session:
                result = await session.execute(delete_query, params)
                deleted_count = result.rowcount

                self.logger.info(
                    f"Deleted {deleted_count} documents by IDs from {full_table_name}"
                )

                return deleted_count

        except Exception as e:
            self.logger.error(f"Error deleting documents by IDs: {e}")
            raise RuntimeError(f"Failed to delete documents by IDs: {e}") from e

    # Additional utility method for safer deletions
    async def count_documents_by_filter(
        self,
        filter_dict: Dict[str, Union[str, List[str]]],
        table: Optional[str] = None,
        schema: Optional[str] = None,
        metadata_column: Optional[str] = None,
        **kwargs
    ) -> int:
        """
        Count documents that would be affected by a filter (useful before deletion).

        Args:
            filter_dict: Dictionary of field_name: value(s) pairs for criteria
            table: Override table name
            schema: Override schema name
            metadata_column: Override metadata column name

        Returns:
            int: Number of documents matching the filter
        """
        if not self._connected:
            await self.connection()

        if not filter_dict:
            return 0

        # Use defaults from instance if not provided
        table = table or self.table_name
        schema = schema or self.schema
        metadata_column = metadata_column or self._document_column or 'cmetadata'

        # Construct full table name
        full_table_name = f"{schema}.{table}" if schema != 'public' else table

        # Build WHERE conditions (same logic as delete_documents_by_filter)
        where_conditions = []
        params = {}

        for field, values in filter_dict.items():
            if isinstance(values, (list, tuple)):
                placeholders = []
                for i, value in enumerate(values):
                    param_name = f"{field}_{i}"
                    placeholders.append(f":{param_name}")
                    params[param_name] = str(value)

                condition = f"{metadata_column}->>'{field}' IN ({', '.join(placeholders)})"
                where_conditions.append(condition)
            else:
                param_name = f"{field}_single"
                where_conditions.append(f"{metadata_column}->>'{field}' = :{param_name}")
                params[param_name] = str(values)

        where_clause = " AND ".join(where_conditions)
        count_query = text(f"""
            SELECT COUNT(*) FROM {full_table_name}
            WHERE {where_clause}
        """)

        try:
            async with self.session() as session:
                result = await session.execute(count_query, params)
                count = result.scalar()

                self.logger.info(
                    f"Found {count} documents matching filter: {filter_dict}"
                )

                return count

        except Exception as e:
            self.logger.error(f"Error counting documents: {e}")
            raise RuntimeError(f"Failed to count documents: {e}") from e

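    # A minimal "count before delete" sketch combining the two methods above,
    # assuming a connected store instance named `store` (illustrative name)
    # and that delete_documents_by_filter mirrors this method's signature:
    #
    #     filters = {"source": "faq.md"}
    #     affected = await store.count_documents_by_filter(filters)
    #     if affected and affected < 1000:
    #         await store.delete_documents_by_filter(filters)
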
    async def from_documents(
        self,
        documents: List[Document],
        table: str = None,
        schema: str = None,
        embedding_column: str = 'embedding',
        content_column: str = 'document',
        metadata_column: str = 'cmetadata',
        chunk_size: int = 8192,
        chunk_overlap: int = 200,
        store_full_document: bool = True,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Add documents using a late chunking strategy.

        Args:
            documents: List of Document objects to process
            table: Table name
            schema: Schema name
            embedding_column: Name of embedding column
            content_column: Name of content column
            metadata_column: Name of metadata column
            chunk_size: Maximum size of each chunk
            chunk_overlap: Overlap between chunks
            store_full_document: Whether to store the full document alongside its chunks

        Returns:
            Dictionary with processing statistics
        """
        if not self._connected:
            await self.connection()

        table = table or self.table_name
        schema = schema or self.schema

        # Initialize late chunking processor
        chunking_processor = LateChunkingProcessor(
            vector_store=self,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )

        # Ensure embedding store is initialized
        if self.embedding_store is None:
            self.embedding_store = self._define_collection_store(
                table=table,
                schema=schema,
                dimension=self.dimension,
                id_column=self._id_column,
                embedding_column=embedding_column,
                document_column=content_column,
                metadata_column=metadata_column,
                text_column=self._text_column,
            )

        all_inserts = []
        stats = {
            'documents_processed': 0,
            'chunks_created': 0,
            'full_documents_stored': 0
        }
        for doc_idx, document in enumerate(documents):
            document_id = f"doc_{doc_idx:06d}_{uuid.uuid4().hex[:8]}"

            # Process document with late chunking
            full_embedding, chunk_infos = await chunking_processor.process_document_late_chunking(
                document_text=document.page_content,
                document_id=document_id,
                metadata=document.metadata
            )
            # Store full document if requested
            if store_full_document:
                full_doc_metadata = {
                    **(document.metadata or {}),
                    'document_id': document_id,
                    'is_full_document': True,
                    'total_chunks': len(chunk_infos),
                    'document_type': 'parent',
                    'chunking_strategy': 'late_chunking'
                }

                all_inserts.append({
                    self._id_column: document_id,
                    embedding_column: full_embedding.tolist(),
                    content_column: document.page_content,
                    metadata_column: full_doc_metadata
                })
                stats['full_documents_stored'] += 1

            # Store all chunks
            for chunk_info in chunk_infos:
                embed = chunk_info.chunk_embedding if isinstance(chunk_info.chunk_embedding, list) else chunk_info.chunk_embedding.tolist()
                all_inserts.append({
                    self._id_column: chunk_info.chunk_id,
                    embedding_column: embed,
                    content_column: chunk_info.chunk_text,
                    metadata_column: chunk_info.metadata
                })
                stats['chunks_created'] += 1

            stats['documents_processed'] += 1

        # Bulk insert all data
        if all_inserts:
            insert_stmt = insert(self.embedding_store)

            try:
                async with self.session() as session:
                    await session.execute(insert_stmt, all_inserts)

                    self.logger.info(
                        f"✅ Late chunking complete: {stats['documents_processed']} documents → "
                        f"{stats['chunks_created']} chunks + {stats['full_documents_stored']} full docs"
                    )

            except Exception as e:
                self.logger.error(f"Error in late chunking insert: {e}")
                raise

        return stats

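    # A minimal ingestion sketch for from_documents above, assuming a connected
    # store instance named `store` (illustrative name) and LangChain-style
    # Document objects exposing page_content and metadata:
    #
    #     docs = [Document(page_content=long_text, metadata={"source": "faq.md"})]
    #     stats = await store.from_documents(docs, chunk_size=4096, chunk_overlap=200)
    #     # stats -> {'documents_processed': 1, 'chunks_created': N,
    #     #           'full_documents_stored': 1}
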
    async def document_search(
        self,
        query: str,
        table: str = None,
        schema: str = None,
        limit: int = 10,
        search_chunks: bool = True,
        search_full_docs: bool = False,
        rerank_with_context: bool = True,
        context_window: int = 2,
        **kwargs
    ) -> List[SearchResult]:
        """
        Search with late chunking context awareness.

        Args:
            query: Search query
            table: Table name
            schema: Schema name
            limit: Number of results
            search_chunks: Whether to search chunk-level embeddings
            search_full_docs: Whether to search full document embeddings
            rerank_with_context: Whether to rerank using surrounding chunks
            context_window: Number of adjacent chunks to include for context

        Returns:
            List of SearchResult objects with enhanced context
        """
        results = []

        # Search chunks if requested
        if search_chunks:
            chunk_filters = {'is_chunk': True}
            chunk_results = await self.similarity_search(
                query=query,
                table=table,
                schema=schema,
                limit=limit * 2,  # Get more candidates for reranking
                metadata_filters=chunk_filters,
                **kwargs
            )
            results.extend(chunk_results)

        # Search full documents if requested
        if search_full_docs:
            doc_filters = {'is_full_document': True}
            doc_results = await self.similarity_search(
                query=query,
                table=table,
                schema=schema,
                limit=limit,
                metadata_filters=doc_filters,
                **kwargs
            )
            results.extend(doc_results)

        # Rerank with context if requested
        if rerank_with_context and search_chunks:
            results = await self._rerank_with_chunk_context(
                results=results,
                query=query,
                table=table,
                schema=schema,
                context_window=context_window
            )

        # Sort by score and limit
        results.sort(key=lambda x: x.score)
        return results[:limit]

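    # A minimal retrieval sketch for document_search above, assuming a
    # connected store instance named `store` (illustrative name). Chunk hits
    # can be reranked with neighbouring chunks of context on each side:
    #
    #     hits = await store.document_search(
    #         "how do I rotate the API key?",
    #         limit=5,
    #         rerank_with_context=True,
    #         context_window=2,
    #     )
    #     for hit in hits:
    #         print(hit.score, hit.metadata.get("parent_document_id"))
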
    async def _rerank_with_chunk_context(
        self,
        results: List[SearchResult],
        query: str,
        table: str,
        schema: str,
        context_window: int = 2
    ) -> List[SearchResult]:
        """
        Rerank results by including surrounding chunk context.
        """
        enhanced_results = []

        for result in results:
            if not result.metadata.get('is_chunk'):
                enhanced_results.append(result)
                continue

            # Get surrounding chunks for context
            parent_id = result.metadata.get('parent_document_id')
            chunk_index = result.metadata.get('chunk_index', 0)

            if parent_id:
                try:
                    # Find adjacent chunks
                    context_chunks = await self._get_adjacent_chunks(
                        parent_id=parent_id,
                        center_chunk_index=chunk_index,
                        window_size=context_window,
                        table=table,
                        schema=schema
                    )

                    # Combine text with context
                    combined_text = self._combine_chunk_context(result, context_chunks)

                    # Re-score with context - ensure embeddings are consistent
                    context_embedding = self._embed_.embed_query(combined_text)
                    query_embedding = self._embed_.embed_query(query)

                    # Ensure both embeddings are numpy arrays
                    if isinstance(context_embedding, list):
                        context_embedding = np.array(context_embedding, dtype=np.float32)
                    if isinstance(query_embedding, list):
                        query_embedding = np.array(query_embedding, dtype=np.float32)

                    # Calculate new similarity score
                    context_score = self._compute_similarity(
                        query_embedding, context_embedding, self.distance_strategy
                    )

                    # Create enhanced result
                    enhanced_metadata = dict(result.metadata)
                    enhanced_metadata['context_score'] = context_score
                    enhanced_metadata['has_context'] = True
                    enhanced_metadata['context_chunks'] = len(context_chunks)

                    enhanced_result = SearchResult(
                        id=result.id,
                        content=combined_text,
                        metadata=enhanced_metadata,
                        score=context_score
                    )

                    enhanced_results.append(enhanced_result)

                except Exception as e:
                    self.logger.warning(f"Error reranking chunk {result.id}: {e}")
                    # Fall back to original result if reranking fails
                    enhanced_results.append(result)
            else:
                enhanced_results.append(result)

        return enhanced_results

    async def _get_adjacent_chunks(
        self,
        parent_id: str,
        center_chunk_index: int,
        window_size: int,
        table: str,
        schema: str
    ) -> List[SearchResult]:
        """Get adjacent chunks for context."""
        # Calculate chunk index range
        start_idx = max(0, center_chunk_index - window_size)
        end_idx = center_chunk_index + window_size + 1

        # Search for chunks in the range
        chunk_filters = {
            'parent_document_id': parent_id,
            'is_chunk': True
        }

        # Get all chunks from parent document
        all_chunks = await self.similarity_search(
            query="dummy",
            table=table,
            schema=schema,
            limit=1000,  # High limit to get all chunks
            metadata_filters=chunk_filters
        )

        # Filter to adjacent chunks
        adjacent_chunks = [
            chunk for chunk in all_chunks
            if start_idx <= chunk.metadata.get('chunk_index', 0) < end_idx
        ]

        # Sort by chunk index
        adjacent_chunks.sort(key=lambda x: x.metadata.get('chunk_index', 0))

        return adjacent_chunks

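    # Window illustration for _get_adjacent_chunks above: with
    # center_chunk_index=5 and window_size=2, start_idx is 3 and end_idx is 8,
    # so chunks with indices 3, 4, 5, 6 and 7 are kept (the upper bound is
    # exclusive).
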
    def _combine_chunk_context(
        self,
        center_result: SearchResult,
        context_chunks: List[SearchResult]
    ) -> str:
        """Combine center chunk with surrounding context."""
        # Sort context chunks by index
        context_chunks.sort(key=lambda x: x.metadata.get('chunk_index', 0))

        # Combine text
        combined_parts = []
        center_idx = center_result.metadata.get('chunk_index', 0)

        for chunk in context_chunks:
            chunk_idx = chunk.metadata.get('chunk_index', 0)
            if chunk_idx == center_idx:
                # Mark the main chunk
                combined_parts.append(f"[MAIN] {chunk.content} [/MAIN]")
            else:
                combined_parts.append(chunk.content)

        return " ... ".join(combined_parts)

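    # Output illustration for _combine_chunk_context above: with three context
    # chunks "A", "B" (the matched chunk) and "C", the combined string is
    # "A ... [MAIN] B [/MAIN] ... C".
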
    async def collection_exists(self, table: str, schema: str = 'public') -> bool:
        """
        Check if a collection (table) exists in the database.

        Args:
            table: Name of the table to check
            schema: Schema name (default: 'public')

        Returns:
            bool: True if the collection exists, False otherwise
        """
        if not self._connected:
            await self.connection()

        async with self.session() as session:
            query = text("""
                SELECT EXISTS (
                    SELECT 1 FROM information_schema.tables
                    WHERE table_schema = :schema AND table_name = :table
                )
            """)
            result = await session.execute(query, {"schema": schema, "table": table})
            return result.scalar()
        return False

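    # A minimal sketch pairing collection_exists with create_collection below,
    # assuming a connected store instance named `store` and purely illustrative
    # table/schema names. create_collection already skips existing tables, but
    # collection_exists allows an explicit check first:
    #
    #     if not await store.collection_exists("kb_chunks", schema="rag"):
    #         await store.create_collection("kb_chunks", schema="rag", dimension=768)
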
    async def delete_collection(
        self,
        table: str,
        schema: str = 'public'
    ) -> None:
        """
        Delete a collection (table) from the database.

        Args:
            table: Name of the table to delete
            schema: Schema name (default: 'public')

        Raises:
            RuntimeError: If the collection does not exist or deletion fails
        """
        if not self._connected:
            await self.connection()

        if not await self.collection_exists(table, schema):
            raise RuntimeError(
                f"Collection {schema}.{table} does not exist"
            )

        async with self.session() as session:
            query = text(
                f"DROP TABLE IF EXISTS {schema}.{table} CASCADE"
            )
            await session.execute(query)
            self.logger.info(
                f"Collection {schema}.{table} deleted successfully"
            )

    async def create_collection(
        self,
        table: str,
        schema: str = 'public',
        dimension: int = 768,
        index_type: str = "COSINE",
        metric_type: str = 'L2',
        id_column: Optional[str] = None,
        **kwargs
    ) -> None:
        """
        Create a new collection (table) in the database.

        Args:
            table: Name of the table to create
            schema: Schema name (default: 'public')
            dimension: Embedding dimension (default: 768)
            index_type: Type of index to create (default: "COSINE")
            metric_type: Distance metric type (default: 'L2')
            id_column: Name of the ID column (default: 'id')
            **kwargs: Additional column options forwarded to prepare_embedding_table,
                such as:
                embedding_column: Name of the embedding column (default: 'embedding')
                document_column: Name of the document content column (default: 'document')
                metadata_column: Name of the metadata column (default: 'cmetadata')

        Raises:
            RuntimeError: If collection creation fails
        """
        if not self._connected:
            await self.connection()

        # Construct full table name
        full_table_name = f"{schema}.{table}" if schema != 'public' else table
        self._metric_type: str = metric_type.upper()
        self._index_type: str = index_type.upper()
        try:
            async with self.session() as session:
                # Check if collection already exists
                if await self.collection_exists(table, schema):
                    self.logger.info(
                        f"Collection {schema}.{table} already exists"
                    )
                else:
                    id_column = id_column or self._id_column or 'id'
                    # Create the collection:
                    self.logger.info(f"Creating collection {schema}.{table}...")
                    create_query = text(f"""
                        CREATE TABLE {full_table_name} (
                            {id_column} TEXT PRIMARY KEY
                        )
                    """)
                    await session.execute(create_query)
                    self.logger.info(
                        f"Collection {schema}.{table} created successfully"
                    )
                    # Execute prepare:
                    await self.prepare_embedding_table(
                        table=table,
                        schema=schema,
                        conn=session,
                        dimension=dimension,
                        id_column=id_column,
                        create_all_indexes=True,
                        **kwargs
                    )
        except Exception as e:
            self.logger.error(f"Error creating collection: {e}")
            raise RuntimeError(
                f"Failed to create collection: {e}"
            ) from e
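
    # An end-to-end lifecycle sketch tying the methods above together, assuming
    # a connected store instance named `store` and table/schema names that are
    # purely illustrative:
    #
    #     await store.create_collection("kb_chunks", schema="rag", dimension=768)
    #     await store.from_documents(docs, table="kb_chunks", schema="rag")
    #     hits = await store.document_search("refund policy", table="kb_chunks",
    #                                        schema="rag", limit=5)
    #     await store.delete_collection("kb_chunks", schema="rag")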