ai-parrot 0.17.2__cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentui/.prettierrc +15 -0
- agentui/QUICKSTART.md +272 -0
- agentui/README.md +59 -0
- agentui/env.example +16 -0
- agentui/jsconfig.json +14 -0
- agentui/package-lock.json +4242 -0
- agentui/package.json +34 -0
- agentui/scripts/postinstall/apply-patches.mjs +260 -0
- agentui/src/app.css +61 -0
- agentui/src/app.d.ts +13 -0
- agentui/src/app.html +12 -0
- agentui/src/components/LoadingSpinner.svelte +64 -0
- agentui/src/components/ThemeSwitcher.svelte +159 -0
- agentui/src/components/index.js +4 -0
- agentui/src/lib/api/bots.ts +60 -0
- agentui/src/lib/api/chat.ts +22 -0
- agentui/src/lib/api/http.ts +25 -0
- agentui/src/lib/components/BotCard.svelte +33 -0
- agentui/src/lib/components/ChatBubble.svelte +63 -0
- agentui/src/lib/components/Toast.svelte +21 -0
- agentui/src/lib/config.ts +20 -0
- agentui/src/lib/stores/auth.svelte.ts +73 -0
- agentui/src/lib/stores/theme.svelte.js +64 -0
- agentui/src/lib/stores/toast.svelte.ts +31 -0
- agentui/src/lib/utils/conversation.ts +39 -0
- agentui/src/routes/+layout.svelte +20 -0
- agentui/src/routes/+page.svelte +232 -0
- agentui/src/routes/login/+page.svelte +200 -0
- agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
- agentui/src/routes/talk/[agentId]/+page.ts +7 -0
- agentui/static/README.md +1 -0
- agentui/svelte.config.js +11 -0
- agentui/tailwind.config.ts +53 -0
- agentui/tsconfig.json +3 -0
- agentui/vite.config.ts +10 -0
- ai_parrot-0.17.2.dist-info/METADATA +472 -0
- ai_parrot-0.17.2.dist-info/RECORD +535 -0
- ai_parrot-0.17.2.dist-info/WHEEL +6 -0
- ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
- ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
- ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
- crew-builder/.prettierrc +15 -0
- crew-builder/QUICKSTART.md +259 -0
- crew-builder/README.md +113 -0
- crew-builder/env.example +17 -0
- crew-builder/jsconfig.json +14 -0
- crew-builder/package-lock.json +4182 -0
- crew-builder/package.json +37 -0
- crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
- crew-builder/src/app.css +62 -0
- crew-builder/src/app.d.ts +13 -0
- crew-builder/src/app.html +12 -0
- crew-builder/src/components/LoadingSpinner.svelte +64 -0
- crew-builder/src/components/ThemeSwitcher.svelte +149 -0
- crew-builder/src/components/index.js +9 -0
- crew-builder/src/lib/api/bots.ts +60 -0
- crew-builder/src/lib/api/chat.ts +80 -0
- crew-builder/src/lib/api/client.ts +56 -0
- crew-builder/src/lib/api/crew/crew.ts +136 -0
- crew-builder/src/lib/api/index.ts +5 -0
- crew-builder/src/lib/api/o365/auth.ts +65 -0
- crew-builder/src/lib/auth/auth.ts +54 -0
- crew-builder/src/lib/components/AgentNode.svelte +43 -0
- crew-builder/src/lib/components/BotCard.svelte +33 -0
- crew-builder/src/lib/components/ChatBubble.svelte +67 -0
- crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
- crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
- crew-builder/src/lib/components/JsonViewer.svelte +24 -0
- crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
- crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
- crew-builder/src/lib/components/Toast.svelte +67 -0
- crew-builder/src/lib/components/Toolbar.svelte +157 -0
- crew-builder/src/lib/components/index.ts +10 -0
- crew-builder/src/lib/config.ts +8 -0
- crew-builder/src/lib/stores/auth.svelte.ts +228 -0
- crew-builder/src/lib/stores/crewStore.ts +369 -0
- crew-builder/src/lib/stores/theme.svelte.js +145 -0
- crew-builder/src/lib/stores/toast.svelte.ts +69 -0
- crew-builder/src/lib/utils/conversation.ts +39 -0
- crew-builder/src/lib/utils/markdown.ts +122 -0
- crew-builder/src/lib/utils/talkHistory.ts +47 -0
- crew-builder/src/routes/+layout.svelte +20 -0
- crew-builder/src/routes/+page.svelte +539 -0
- crew-builder/src/routes/agents/+page.svelte +247 -0
- crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
- crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
- crew-builder/src/routes/builder/+page.svelte +204 -0
- crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
- crew-builder/src/routes/crew/ask/+page.ts +1 -0
- crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
- crew-builder/src/routes/login/+page.svelte +197 -0
- crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
- crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
- crew-builder/static/README.md +1 -0
- crew-builder/svelte.config.js +11 -0
- crew-builder/tailwind.config.ts +53 -0
- crew-builder/tsconfig.json +3 -0
- crew-builder/vite.config.ts +10 -0
- mcp_servers/calculator_server.py +309 -0
- parrot/__init__.py +27 -0
- parrot/__pycache__/__init__.cpython-310.pyc +0 -0
- parrot/__pycache__/version.cpython-310.pyc +0 -0
- parrot/_version.py +34 -0
- parrot/a2a/__init__.py +48 -0
- parrot/a2a/client.py +658 -0
- parrot/a2a/discovery.py +89 -0
- parrot/a2a/mixin.py +257 -0
- parrot/a2a/models.py +376 -0
- parrot/a2a/server.py +770 -0
- parrot/agents/__init__.py +29 -0
- parrot/bots/__init__.py +12 -0
- parrot/bots/a2a_agent.py +19 -0
- parrot/bots/abstract.py +3139 -0
- parrot/bots/agent.py +1129 -0
- parrot/bots/basic.py +9 -0
- parrot/bots/chatbot.py +669 -0
- parrot/bots/data.py +1618 -0
- parrot/bots/database/__init__.py +5 -0
- parrot/bots/database/abstract.py +3071 -0
- parrot/bots/database/cache.py +286 -0
- parrot/bots/database/models.py +468 -0
- parrot/bots/database/prompts.py +154 -0
- parrot/bots/database/retries.py +98 -0
- parrot/bots/database/router.py +269 -0
- parrot/bots/database/sql.py +41 -0
- parrot/bots/db/__init__.py +6 -0
- parrot/bots/db/abstract.py +556 -0
- parrot/bots/db/bigquery.py +602 -0
- parrot/bots/db/cache.py +85 -0
- parrot/bots/db/documentdb.py +668 -0
- parrot/bots/db/elastic.py +1014 -0
- parrot/bots/db/influx.py +898 -0
- parrot/bots/db/mock.py +96 -0
- parrot/bots/db/multi.py +783 -0
- parrot/bots/db/prompts.py +185 -0
- parrot/bots/db/sql.py +1255 -0
- parrot/bots/db/tools.py +212 -0
- parrot/bots/document.py +680 -0
- parrot/bots/hrbot.py +15 -0
- parrot/bots/kb.py +170 -0
- parrot/bots/mcp.py +36 -0
- parrot/bots/orchestration/README.md +463 -0
- parrot/bots/orchestration/__init__.py +1 -0
- parrot/bots/orchestration/agent.py +155 -0
- parrot/bots/orchestration/crew.py +3330 -0
- parrot/bots/orchestration/fsm.py +1179 -0
- parrot/bots/orchestration/hr.py +434 -0
- parrot/bots/orchestration/storage/__init__.py +4 -0
- parrot/bots/orchestration/storage/memory.py +100 -0
- parrot/bots/orchestration/storage/mixin.py +119 -0
- parrot/bots/orchestration/verify.py +202 -0
- parrot/bots/product.py +204 -0
- parrot/bots/prompts/__init__.py +96 -0
- parrot/bots/prompts/agents.py +155 -0
- parrot/bots/prompts/data.py +216 -0
- parrot/bots/prompts/output_generation.py +8 -0
- parrot/bots/scraper/__init__.py +3 -0
- parrot/bots/scraper/models.py +122 -0
- parrot/bots/scraper/scraper.py +1173 -0
- parrot/bots/scraper/templates.py +115 -0
- parrot/bots/stores/__init__.py +5 -0
- parrot/bots/stores/local.py +172 -0
- parrot/bots/webdev.py +81 -0
- parrot/cli.py +17 -0
- parrot/clients/__init__.py +16 -0
- parrot/clients/base.py +1491 -0
- parrot/clients/claude.py +1191 -0
- parrot/clients/factory.py +129 -0
- parrot/clients/google.py +4567 -0
- parrot/clients/gpt.py +1975 -0
- parrot/clients/grok.py +432 -0
- parrot/clients/groq.py +986 -0
- parrot/clients/hf.py +582 -0
- parrot/clients/models.py +18 -0
- parrot/conf.py +395 -0
- parrot/embeddings/__init__.py +9 -0
- parrot/embeddings/base.py +157 -0
- parrot/embeddings/google.py +98 -0
- parrot/embeddings/huggingface.py +74 -0
- parrot/embeddings/openai.py +84 -0
- parrot/embeddings/processor.py +88 -0
- parrot/exceptions.c +13868 -0
- parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/exceptions.pxd +22 -0
- parrot/exceptions.pxi +15 -0
- parrot/exceptions.pyx +44 -0
- parrot/generators/__init__.py +29 -0
- parrot/generators/base.py +200 -0
- parrot/generators/html.py +293 -0
- parrot/generators/react.py +205 -0
- parrot/generators/streamlit.py +203 -0
- parrot/generators/template.py +105 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/agent.py +861 -0
- parrot/handlers/agents/__init__.py +1 -0
- parrot/handlers/agents/abstract.py +900 -0
- parrot/handlers/bots.py +338 -0
- parrot/handlers/chat.py +915 -0
- parrot/handlers/creation.sql +192 -0
- parrot/handlers/crew/ARCHITECTURE.md +362 -0
- parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
- parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
- parrot/handlers/crew/__init__.py +0 -0
- parrot/handlers/crew/handler.py +801 -0
- parrot/handlers/crew/models.py +229 -0
- parrot/handlers/crew/redis_persistence.py +523 -0
- parrot/handlers/jobs/__init__.py +10 -0
- parrot/handlers/jobs/job.py +384 -0
- parrot/handlers/jobs/mixin.py +627 -0
- parrot/handlers/jobs/models.py +115 -0
- parrot/handlers/jobs/worker.py +31 -0
- parrot/handlers/models.py +596 -0
- parrot/handlers/o365_auth.py +105 -0
- parrot/handlers/stream.py +337 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/aws.py +143 -0
- parrot/interfaces/credentials.py +113 -0
- parrot/interfaces/database.py +27 -0
- parrot/interfaces/google.py +1123 -0
- parrot/interfaces/hierarchy.py +1227 -0
- parrot/interfaces/http.py +651 -0
- parrot/interfaces/images/__init__.py +0 -0
- parrot/interfaces/images/plugins/__init__.py +24 -0
- parrot/interfaces/images/plugins/abstract.py +58 -0
- parrot/interfaces/images/plugins/analisys.py +148 -0
- parrot/interfaces/images/plugins/classify.py +150 -0
- parrot/interfaces/images/plugins/classifybase.py +182 -0
- parrot/interfaces/images/plugins/detect.py +150 -0
- parrot/interfaces/images/plugins/exif.py +1103 -0
- parrot/interfaces/images/plugins/hash.py +52 -0
- parrot/interfaces/images/plugins/vision.py +104 -0
- parrot/interfaces/images/plugins/yolo.py +66 -0
- parrot/interfaces/images/plugins/zerodetect.py +197 -0
- parrot/interfaces/o365.py +978 -0
- parrot/interfaces/onedrive.py +822 -0
- parrot/interfaces/sharepoint.py +1435 -0
- parrot/interfaces/soap.py +257 -0
- parrot/loaders/__init__.py +8 -0
- parrot/loaders/abstract.py +1131 -0
- parrot/loaders/audio.py +199 -0
- parrot/loaders/basepdf.py +53 -0
- parrot/loaders/basevideo.py +1568 -0
- parrot/loaders/csv.py +409 -0
- parrot/loaders/docx.py +116 -0
- parrot/loaders/epubloader.py +316 -0
- parrot/loaders/excel.py +199 -0
- parrot/loaders/factory.py +55 -0
- parrot/loaders/files/__init__.py +0 -0
- parrot/loaders/files/abstract.py +39 -0
- parrot/loaders/files/html.py +26 -0
- parrot/loaders/files/text.py +63 -0
- parrot/loaders/html.py +152 -0
- parrot/loaders/markdown.py +442 -0
- parrot/loaders/pdf.py +373 -0
- parrot/loaders/pdfmark.py +320 -0
- parrot/loaders/pdftables.py +506 -0
- parrot/loaders/ppt.py +476 -0
- parrot/loaders/qa.py +63 -0
- parrot/loaders/splitters/__init__.py +10 -0
- parrot/loaders/splitters/base.py +138 -0
- parrot/loaders/splitters/md.py +228 -0
- parrot/loaders/splitters/token.py +143 -0
- parrot/loaders/txt.py +26 -0
- parrot/loaders/video.py +89 -0
- parrot/loaders/videolocal.py +218 -0
- parrot/loaders/videounderstanding.py +377 -0
- parrot/loaders/vimeo.py +167 -0
- parrot/loaders/web.py +599 -0
- parrot/loaders/youtube.py +504 -0
- parrot/manager/__init__.py +5 -0
- parrot/manager/manager.py +1030 -0
- parrot/mcp/__init__.py +28 -0
- parrot/mcp/adapter.py +105 -0
- parrot/mcp/cli.py +174 -0
- parrot/mcp/client.py +119 -0
- parrot/mcp/config.py +75 -0
- parrot/mcp/integration.py +842 -0
- parrot/mcp/oauth.py +933 -0
- parrot/mcp/server.py +225 -0
- parrot/mcp/transports/__init__.py +3 -0
- parrot/mcp/transports/base.py +279 -0
- parrot/mcp/transports/grpc_session.py +163 -0
- parrot/mcp/transports/http.py +312 -0
- parrot/mcp/transports/mcp.proto +108 -0
- parrot/mcp/transports/quic.py +1082 -0
- parrot/mcp/transports/sse.py +330 -0
- parrot/mcp/transports/stdio.py +309 -0
- parrot/mcp/transports/unix.py +395 -0
- parrot/mcp/transports/websocket.py +547 -0
- parrot/memory/__init__.py +16 -0
- parrot/memory/abstract.py +209 -0
- parrot/memory/agent.py +32 -0
- parrot/memory/cache.py +175 -0
- parrot/memory/core.py +555 -0
- parrot/memory/file.py +153 -0
- parrot/memory/mem.py +131 -0
- parrot/memory/redis.py +613 -0
- parrot/models/__init__.py +46 -0
- parrot/models/basic.py +118 -0
- parrot/models/compliance.py +208 -0
- parrot/models/crew.py +395 -0
- parrot/models/detections.py +654 -0
- parrot/models/generation.py +85 -0
- parrot/models/google.py +223 -0
- parrot/models/groq.py +23 -0
- parrot/models/openai.py +30 -0
- parrot/models/outputs.py +285 -0
- parrot/models/responses.py +938 -0
- parrot/notifications/__init__.py +743 -0
- parrot/openapi/__init__.py +3 -0
- parrot/openapi/components.yaml +641 -0
- parrot/openapi/config.py +322 -0
- parrot/outputs/__init__.py +32 -0
- parrot/outputs/formats/__init__.py +108 -0
- parrot/outputs/formats/altair.py +359 -0
- parrot/outputs/formats/application.py +122 -0
- parrot/outputs/formats/base.py +351 -0
- parrot/outputs/formats/bokeh.py +356 -0
- parrot/outputs/formats/card.py +424 -0
- parrot/outputs/formats/chart.py +436 -0
- parrot/outputs/formats/d3.py +255 -0
- parrot/outputs/formats/echarts.py +310 -0
- parrot/outputs/formats/generators/__init__.py +0 -0
- parrot/outputs/formats/generators/abstract.py +61 -0
- parrot/outputs/formats/generators/panel.py +145 -0
- parrot/outputs/formats/generators/streamlit.py +86 -0
- parrot/outputs/formats/generators/terminal.py +63 -0
- parrot/outputs/formats/holoviews.py +310 -0
- parrot/outputs/formats/html.py +147 -0
- parrot/outputs/formats/jinja2.py +46 -0
- parrot/outputs/formats/json.py +87 -0
- parrot/outputs/formats/map.py +933 -0
- parrot/outputs/formats/markdown.py +172 -0
- parrot/outputs/formats/matplotlib.py +237 -0
- parrot/outputs/formats/mixins/__init__.py +0 -0
- parrot/outputs/formats/mixins/emaps.py +855 -0
- parrot/outputs/formats/plotly.py +341 -0
- parrot/outputs/formats/seaborn.py +310 -0
- parrot/outputs/formats/table.py +397 -0
- parrot/outputs/formats/template_report.py +138 -0
- parrot/outputs/formats/yaml.py +125 -0
- parrot/outputs/formatter.py +152 -0
- parrot/outputs/templates/__init__.py +95 -0
- parrot/pipelines/__init__.py +0 -0
- parrot/pipelines/abstract.py +210 -0
- parrot/pipelines/detector.py +124 -0
- parrot/pipelines/models.py +90 -0
- parrot/pipelines/planogram.py +3002 -0
- parrot/pipelines/table.sql +97 -0
- parrot/plugins/__init__.py +106 -0
- parrot/plugins/importer.py +80 -0
- parrot/py.typed +0 -0
- parrot/registry/__init__.py +18 -0
- parrot/registry/registry.py +594 -0
- parrot/scheduler/__init__.py +1189 -0
- parrot/scheduler/models.py +60 -0
- parrot/security/__init__.py +16 -0
- parrot/security/prompt_injection.py +268 -0
- parrot/security/security_events.sql +25 -0
- parrot/services/__init__.py +1 -0
- parrot/services/mcp/__init__.py +8 -0
- parrot/services/mcp/config.py +13 -0
- parrot/services/mcp/server.py +295 -0
- parrot/services/o365_remote_auth.py +235 -0
- parrot/stores/__init__.py +7 -0
- parrot/stores/abstract.py +352 -0
- parrot/stores/arango.py +1090 -0
- parrot/stores/bigquery.py +1377 -0
- parrot/stores/cache.py +106 -0
- parrot/stores/empty.py +10 -0
- parrot/stores/faiss_store.py +1157 -0
- parrot/stores/kb/__init__.py +9 -0
- parrot/stores/kb/abstract.py +68 -0
- parrot/stores/kb/cache.py +165 -0
- parrot/stores/kb/doc.py +325 -0
- parrot/stores/kb/hierarchy.py +346 -0
- parrot/stores/kb/local.py +457 -0
- parrot/stores/kb/prompt.py +28 -0
- parrot/stores/kb/redis.py +659 -0
- parrot/stores/kb/store.py +115 -0
- parrot/stores/kb/user.py +374 -0
- parrot/stores/models.py +59 -0
- parrot/stores/pgvector.py +3 -0
- parrot/stores/postgres.py +2853 -0
- parrot/stores/utils/__init__.py +0 -0
- parrot/stores/utils/chunking.py +197 -0
- parrot/telemetry/__init__.py +3 -0
- parrot/telemetry/mixin.py +111 -0
- parrot/template/__init__.py +3 -0
- parrot/template/engine.py +259 -0
- parrot/tools/__init__.py +23 -0
- parrot/tools/abstract.py +644 -0
- parrot/tools/agent.py +363 -0
- parrot/tools/arangodbsearch.py +537 -0
- parrot/tools/arxiv_tool.py +188 -0
- parrot/tools/calculator/__init__.py +3 -0
- parrot/tools/calculator/operations/__init__.py +38 -0
- parrot/tools/calculator/operations/calculus.py +80 -0
- parrot/tools/calculator/operations/statistics.py +76 -0
- parrot/tools/calculator/tool.py +150 -0
- parrot/tools/cloudwatch.py +988 -0
- parrot/tools/codeinterpreter/__init__.py +127 -0
- parrot/tools/codeinterpreter/executor.py +371 -0
- parrot/tools/codeinterpreter/internals.py +473 -0
- parrot/tools/codeinterpreter/models.py +643 -0
- parrot/tools/codeinterpreter/prompts.py +224 -0
- parrot/tools/codeinterpreter/tool.py +664 -0
- parrot/tools/company_info/__init__.py +6 -0
- parrot/tools/company_info/tool.py +1138 -0
- parrot/tools/correlationanalysis.py +437 -0
- parrot/tools/database/abstract.py +286 -0
- parrot/tools/database/bq.py +115 -0
- parrot/tools/database/cache.py +284 -0
- parrot/tools/database/models.py +95 -0
- parrot/tools/database/pg.py +343 -0
- parrot/tools/databasequery.py +1159 -0
- parrot/tools/db.py +1800 -0
- parrot/tools/ddgo.py +370 -0
- parrot/tools/decorators.py +271 -0
- parrot/tools/dftohtml.py +282 -0
- parrot/tools/document.py +549 -0
- parrot/tools/ecs.py +819 -0
- parrot/tools/edareport.py +368 -0
- parrot/tools/elasticsearch.py +1049 -0
- parrot/tools/employees.py +462 -0
- parrot/tools/epson/__init__.py +96 -0
- parrot/tools/excel.py +683 -0
- parrot/tools/file/__init__.py +13 -0
- parrot/tools/file/abstract.py +76 -0
- parrot/tools/file/gcs.py +378 -0
- parrot/tools/file/local.py +284 -0
- parrot/tools/file/s3.py +511 -0
- parrot/tools/file/tmp.py +309 -0
- parrot/tools/file/tool.py +501 -0
- parrot/tools/file_reader.py +129 -0
- parrot/tools/flowtask/__init__.py +19 -0
- parrot/tools/flowtask/tool.py +761 -0
- parrot/tools/gittoolkit.py +508 -0
- parrot/tools/google/__init__.py +18 -0
- parrot/tools/google/base.py +169 -0
- parrot/tools/google/tools.py +1251 -0
- parrot/tools/googlelocation.py +5 -0
- parrot/tools/googleroutes.py +5 -0
- parrot/tools/googlesearch.py +5 -0
- parrot/tools/googlesitesearch.py +5 -0
- parrot/tools/googlevoice.py +2 -0
- parrot/tools/gvoice.py +695 -0
- parrot/tools/ibisworld/README.md +225 -0
- parrot/tools/ibisworld/__init__.py +11 -0
- parrot/tools/ibisworld/tool.py +366 -0
- parrot/tools/jiratoolkit.py +1718 -0
- parrot/tools/manager.py +1098 -0
- parrot/tools/math.py +152 -0
- parrot/tools/metadata.py +476 -0
- parrot/tools/msteams.py +1621 -0
- parrot/tools/msword.py +635 -0
- parrot/tools/multidb.py +580 -0
- parrot/tools/multistoresearch.py +369 -0
- parrot/tools/networkninja.py +167 -0
- parrot/tools/nextstop/__init__.py +4 -0
- parrot/tools/nextstop/base.py +286 -0
- parrot/tools/nextstop/employee.py +733 -0
- parrot/tools/nextstop/store.py +462 -0
- parrot/tools/notification.py +435 -0
- parrot/tools/o365/__init__.py +42 -0
- parrot/tools/o365/base.py +295 -0
- parrot/tools/o365/bundle.py +522 -0
- parrot/tools/o365/events.py +554 -0
- parrot/tools/o365/mail.py +992 -0
- parrot/tools/o365/onedrive.py +497 -0
- parrot/tools/o365/sharepoint.py +641 -0
- parrot/tools/openapi_toolkit.py +904 -0
- parrot/tools/openweather.py +527 -0
- parrot/tools/pdfprint.py +1001 -0
- parrot/tools/powerbi.py +518 -0
- parrot/tools/powerpoint.py +1113 -0
- parrot/tools/pricestool.py +146 -0
- parrot/tools/products/__init__.py +246 -0
- parrot/tools/prophet_tool.py +171 -0
- parrot/tools/pythonpandas.py +630 -0
- parrot/tools/pythonrepl.py +910 -0
- parrot/tools/qsource.py +436 -0
- parrot/tools/querytoolkit.py +395 -0
- parrot/tools/quickeda.py +827 -0
- parrot/tools/resttool.py +553 -0
- parrot/tools/retail/__init__.py +0 -0
- parrot/tools/retail/bby.py +528 -0
- parrot/tools/sandboxtool.py +703 -0
- parrot/tools/sassie/__init__.py +352 -0
- parrot/tools/scraping/__init__.py +7 -0
- parrot/tools/scraping/docs/select.md +466 -0
- parrot/tools/scraping/documentation.md +1278 -0
- parrot/tools/scraping/driver.py +436 -0
- parrot/tools/scraping/models.py +576 -0
- parrot/tools/scraping/options.py +85 -0
- parrot/tools/scraping/orchestrator.py +517 -0
- parrot/tools/scraping/readme.md +740 -0
- parrot/tools/scraping/tool.py +3115 -0
- parrot/tools/seasonaldetection.py +642 -0
- parrot/tools/shell_tool/__init__.py +5 -0
- parrot/tools/shell_tool/actions.py +408 -0
- parrot/tools/shell_tool/engine.py +155 -0
- parrot/tools/shell_tool/models.py +322 -0
- parrot/tools/shell_tool/tool.py +442 -0
- parrot/tools/site_search.py +214 -0
- parrot/tools/textfile.py +418 -0
- parrot/tools/think.py +378 -0
- parrot/tools/toolkit.py +298 -0
- parrot/tools/webapp_tool.py +187 -0
- parrot/tools/whatif.py +1279 -0
- parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
- parrot/tools/workday/__init__.py +6 -0
- parrot/tools/workday/models.py +1389 -0
- parrot/tools/workday/tool.py +1293 -0
- parrot/tools/yfinance_tool.py +306 -0
- parrot/tools/zipcode.py +217 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/helpers.py +73 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.c +12078 -0
- parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/parsers/toml.pyx +21 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpp +20936 -0
- parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/types.pyx +213 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- parrot/yaml-rs/Cargo.lock +350 -0
- parrot/yaml-rs/Cargo.toml +19 -0
- parrot/yaml-rs/pyproject.toml +19 -0
- parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
- parrot/yaml-rs/src/lib.rs +222 -0
- requirements/docker-compose.yml +24 -0
- requirements/requirements-dev.txt +21 -0
|
@@ -0,0 +1,1157 @@
|
|
|
1
|
+
"""
|
|
2
|
+
FAISSStore: In-memory Vector Store implementation using FAISS.
|
|
3
|
+
|
|
4
|
+
Provides high-performance vector similarity search with:
|
|
5
|
+
- In-memory vector storage with FAISS indexes
|
|
6
|
+
- Multiple distance metrics (Cosine, L2, Inner Product)
|
|
7
|
+
- CPU-only execution (GPU support removed)
|
|
8
|
+
- MMR (Maximal Marginal Relevance) search
|
|
9
|
+
- Metadata filtering
|
|
10
|
+
- Collection management
|
|
11
|
+
- Async context manager support
|
|
12
|
+
"""
|
|
13
|
+
from typing import Any, Dict, List, Optional, Union, Callable
|
|
14
|
+
import uuid
|
|
15
|
+
import pickle
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
import numpy as np
|
|
18
|
+
from navconfig.logging import logging
|
|
19
|
+
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
|
|
20
|
+
|
|
21
|
+
try:
|
|
22
|
+
import faiss
|
|
23
|
+
FAISS_AVAILABLE = True
|
|
24
|
+
except ImportError:
|
|
25
|
+
FAISS_AVAILABLE = False
|
|
26
|
+
faiss = None
|
|
27
|
+
|
|
28
|
+
from .abstract import AbstractStore
|
|
29
|
+
from .models import Document, SearchResult, DistanceStrategy
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class FAISSStore(AbstractStore):
    """
    An in-memory FAISS vector store implementation, completely independent of Langchain.

    This store provides high-performance vector similarity search using FAISS indexes
    with support for multiple distance metrics and metadata filtering.

    Features:
    - Multiple FAISS index types (Flat, IVF, HNSW)
    - CPU-only execution
    - Cosine, L2, and Inner Product distance metrics
    - MMR (Maximal Marginal Relevance) search
    - Metadata filtering
    - Persistent storage via save/load
    """
def __init__(
|
|
49
|
+
self,
|
|
50
|
+
collection_name: str = "default",
|
|
51
|
+
id_column: str = 'id',
|
|
52
|
+
embedding_column: str = 'embedding',
|
|
53
|
+
document_column: str = 'document',
|
|
54
|
+
text_column: str = 'text',
|
|
55
|
+
metadata_column: str = 'metadata',
|
|
56
|
+
embedding_model: Union[dict, str] = "sentence-transformers/all-mpnet-base-v2",
|
|
57
|
+
embedding: Optional[Callable] = None,
|
|
58
|
+
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
|
|
59
|
+
index_type: str = "Flat", # Options: "Flat", "IVF", "HNSW"
|
|
60
|
+
nlist: int = 100, # For IVF indexes
|
|
61
|
+
nprobe: int = 10, # For IVF search
|
|
62
|
+
m: int = 32, # For HNSW
|
|
63
|
+
ef_construction: int = 40, # For HNSW
|
|
64
|
+
ef_search: int = 16, # For HNSW
|
|
65
|
+
**kwargs
|
|
66
|
+
):
|
|
67
|
+
"""
|
|
68
|
+
Initialize FAISSStore with the specified parameters.
|
|
69
|
+
|
|
70
|
+
Args:
|
|
71
|
+
collection_name: Name of the collection/index
|
|
72
|
+
id_column: Name of the ID column
|
|
73
|
+
embedding_column: Name of the embedding column
|
|
74
|
+
document_column: Name of the document content column
|
|
75
|
+
text_column: Name of the text column
|
|
76
|
+
metadata_column: Name of the metadata column
|
|
77
|
+
embedding_model: Embedding model configuration (dict or string)
|
|
78
|
+
embedding: Custom embedding function
|
|
79
|
+
distance_strategy: Distance metric to use (COSINE, EUCLIDEAN_DISTANCE, etc.)
|
|
80
|
+
index_type: Type of FAISS index ("Flat", "IVF", "HNSW")
|
|
81
|
+
nlist: Number of clusters for IVF indexes
|
|
82
|
+
nprobe: Number of clusters to probe for IVF search
|
|
83
|
+
m: Number of connections per layer for HNSW
|
|
84
|
+
ef_construction: Size of dynamic candidate list for HNSW construction
|
|
85
|
+
ef_search: Size of dynamic candidate list for HNSW search
|
|
86
|
+
"""
|
|
87
|
+
if not FAISS_AVAILABLE:
|
|
88
|
+
raise ImportError(
|
|
89
|
+
"FAISS is not installed. Please install it with: pip install faiss-cpu"
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
# Store configuration
|
|
93
|
+
self.collection_name = collection_name
|
|
94
|
+
self._id_column: str = id_column
|
|
95
|
+
self._embedding_column: str = embedding_column
|
|
96
|
+
self._document_column: str = document_column
|
|
97
|
+
self._text_column: str = text_column
|
|
98
|
+
self._metadata_column: str = metadata_column
|
|
99
|
+
|
|
100
|
+
# FAISS configuration
|
|
101
|
+
self.index_type = index_type
|
|
102
|
+
self.nlist = nlist
|
|
103
|
+
self.nprobe = nprobe
|
|
104
|
+
self.m = m
|
|
105
|
+
self.ef_construction = ef_construction
|
|
106
|
+
self.ef_search = ef_search
|
|
107
|
+
|
|
108
|
+
# Distance strategy - normalize to enum
|
|
109
|
+
if isinstance(distance_strategy, str):
|
|
110
|
+
# Convert string to DistanceStrategy enum
|
|
111
|
+
try:
|
|
112
|
+
self.distance_strategy = DistanceStrategy[distance_strategy.upper()]
|
|
113
|
+
except KeyError:
|
|
114
|
+
self.logger.warning(
|
|
115
|
+
f"Unknown distance strategy '{distance_strategy}', using COSINE"
|
|
116
|
+
)
|
|
117
|
+
self.distance_strategy = DistanceStrategy.COSINE
|
|
118
|
+
elif isinstance(distance_strategy, DistanceStrategy):
|
|
119
|
+
self.distance_strategy = distance_strategy
|
|
120
|
+
else:
|
|
121
|
+
# Default to COSINE if invalid type
|
|
122
|
+
self.logger.warning(
|
|
123
|
+
f"Invalid distance_strategy type: {type(distance_strategy)}, using COSINE"
|
|
124
|
+
)
|
|
125
|
+
self.distance_strategy = DistanceStrategy.COSINE
|
|
126
|
+
|
|
127
|
+
# Initialize parent class
|
|
128
|
+
super().__init__(
|
|
129
|
+
embedding_model=embedding_model,
|
|
130
|
+
embedding=embedding,
|
|
131
|
+
**kwargs
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
# Collections store: {collection_name: collection_data}
|
|
135
|
+
self._collections: Dict[str, Dict[str, Any]] = {}
|
|
136
|
+
|
|
137
|
+
# Initialize logger
|
|
138
|
+
self.logger = logging.getLogger("FAISSStore")
|
|
139
|
+
|
|
140
|
+
# Connection state
|
|
141
|
+
self._connected: bool = False
|
|
142
|
+
self._connection = None # For compatibility with abstract interface
|
|
143
|
+
|
|
144
|
+
# Initialize default collection
|
|
145
|
+
if collection_name:
|
|
146
|
+
self._initialize_collection(collection_name)
|
|
147
|
+
|
|
148
|
+
def _initialize_collection(self, collection_name: str) -> None:
|
|
149
|
+
"""Initialize a new collection with empty data structures."""
|
|
150
|
+
if collection_name not in self._collections:
|
|
151
|
+
self._collections[collection_name] = {
|
|
152
|
+
'index': None, # FAISS index
|
|
153
|
+
'documents': {}, # {id: document_content}
|
|
154
|
+
'metadata': {}, # {id: metadata_dict}
|
|
155
|
+
'embeddings': {}, # {id: embedding_vector}
|
|
156
|
+
'id_to_idx': {}, # {id: faiss_index_position}
|
|
157
|
+
'idx_to_id': {}, # {faiss_index_position: id}
|
|
158
|
+
'dimension': None,
|
|
159
|
+
'is_trained': False,
|
|
160
|
+
}
|
|
161
|
+
self.logger.info(f"Initialized collection: {collection_name}")
|
|
162
|
+
|
|
163
|
+
def define_collection_table(
|
|
164
|
+
self,
|
|
165
|
+
collection_name: str,
|
|
166
|
+
dimension: int = 384,
|
|
167
|
+
**kwargs
|
|
168
|
+
) -> Dict[str, Any]:
|
|
169
|
+
"""
|
|
170
|
+
Define an in-memory collection table for saving vector + metadata information.
|
|
171
|
+
|
|
172
|
+
This method is compatible with the PgVectorStore pattern but operates in-memory.
|
|
173
|
+
|
|
174
|
+
Args:
|
|
175
|
+
collection_name: Name of the collection
|
|
176
|
+
dimension: Dimension of the embedding vectors
|
|
177
|
+
**kwargs: Additional arguments (for compatibility)
|
|
178
|
+
|
|
179
|
+
Returns:
|
|
180
|
+
Dictionary representing the collection structure
|
|
181
|
+
"""
|
|
182
|
+
if collection_name not in self._collections:
|
|
183
|
+
self._initialize_collection(collection_name)
|
|
184
|
+
|
|
185
|
+
collection = self._collections[collection_name]
|
|
186
|
+
collection['dimension'] = dimension
|
|
187
|
+
|
|
188
|
+
# Create FAISS index based on configuration
|
|
189
|
+
index = self._create_faiss_index(dimension)
|
|
190
|
+
collection['index'] = index
|
|
191
|
+
|
|
192
|
+
self.logger.info(
|
|
193
|
+
f"Defined collection table '{collection_name}' with dimension {dimension}"
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
return {
|
|
197
|
+
'collection_name': collection_name,
|
|
198
|
+
'dimension': dimension,
|
|
199
|
+
'index_type': self.index_type,
|
|
200
|
+
'distance_strategy': self.distance_strategy.value,
|
|
201
|
+
}
|
|
202
|
+
|
|
203
|
+
def _create_faiss_index(self, dimension: int) -> Any:
|
|
204
|
+
"""
|
|
205
|
+
Create a FAISS index based on the configured index type and distance strategy.
|
|
206
|
+
|
|
207
|
+
Args:
|
|
208
|
+
dimension: Dimension of the embedding vectors
|
|
209
|
+
|
|
210
|
+
Returns:
|
|
211
|
+
FAISS index object
|
|
212
|
+
"""
|
|
213
|
+
# Determine the metric type based on distance strategy
|
|
214
|
+
if self.distance_strategy == DistanceStrategy.COSINE:
|
|
215
|
+
# For cosine similarity, we'll normalize vectors and use inner product
|
|
216
|
+
metric = faiss.METRIC_INNER_PRODUCT
|
|
217
|
+
elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
|
|
218
|
+
metric = faiss.METRIC_L2
|
|
219
|
+
elif self.distance_strategy in [DistanceStrategy.DOT_PRODUCT, DistanceStrategy.MAX_INNER_PRODUCT]:
|
|
220
|
+
metric = faiss.METRIC_INNER_PRODUCT
|
|
221
|
+
else:
|
|
222
|
+
# Default to inner product
|
|
223
|
+
metric = faiss.METRIC_INNER_PRODUCT
|
|
224
|
+
|
|
225
|
+
# Create index based on type
|
|
226
|
+
if self.index_type == "Flat":
|
|
227
|
+
index = faiss.IndexFlatIP(dimension) if metric == faiss.METRIC_INNER_PRODUCT else faiss.IndexFlatL2(dimension)
|
|
228
|
+
|
|
229
|
+
elif self.index_type == "IVF":
|
|
230
|
+
# IVF (Inverted File Index) for faster search on large datasets
|
|
231
|
+
quantizer = faiss.IndexFlatIP(dimension) if metric == faiss.METRIC_INNER_PRODUCT else faiss.IndexFlatL2(dimension)
|
|
232
|
+
index = faiss.IndexIVFFlat(quantizer, dimension, self.nlist, metric)
|
|
233
|
+
index.nprobe = self.nprobe
|
|
234
|
+
|
|
235
|
+
elif self.index_type == "HNSW":
|
|
236
|
+
# HNSW (Hierarchical Navigable Small World) for very fast search
|
|
237
|
+
index = faiss.IndexHNSWFlat(dimension, self.m, metric)
|
|
238
|
+
index.hnsw.efConstruction = self.ef_construction
|
|
239
|
+
index.hnsw.efSearch = self.ef_search
|
|
240
|
+
|
|
241
|
+
else:
|
|
242
|
+
raise ValueError(f"Unsupported index type: {self.index_type}")
|
|
243
|
+
|
|
244
|
+
self.logger.info(
|
|
245
|
+
f"Created FAISS index: type={self.index_type}, "
|
|
246
|
+
f"metric={metric}, dimension={dimension}, cpu_only=True"
|
|
247
|
+
)
|
|
248
|
+
|
|
249
|
+
return index
|
|
250
|
+
|
|
251
|
+
async def connection(self) -> bool:
    """Mark the in-memory FAISS store as connected.

    FAISS keeps everything in process memory, so there is no real
    connection to open; this only flips the bookkeeping flags expected
    by AbstractStore-compatible callers.

    Returns:
        True always.
    """
    if self._connected:
        return True
    self._connected = True
    # No backing server exists; a truthy placeholder keeps callers that
    # inspect ``_connection`` happy.
    self._connection = True
    self.logger.info("FAISSStore connection established (in-memory)")
    return True
|
|
265
|
+
|
|
266
|
+
async def disconnect(self) -> None:
    """Disconnect and release all in-memory resources.

    Drops every collection's FAISS index and clears the collection
    registry; the store can be re-connected later via :meth:`connection`.
    No-op when the store is not connected.
    """
    if not self._connected:
        return
    # Drop FAISS indexes explicitly so their native memory can be
    # reclaimed before the registry itself is cleared.
    # (Original iterated .items() and discarded the key; .values() is
    # the idiomatic form.)
    for collection in self._collections.values():
        if collection.get('index'):
            del collection['index']

    # Clear collections
    self._collections.clear()

    self._connected = False
    self._connection = None
    self.logger.info("FAISSStore disconnected and resources cleared")
|
|
285
|
+
|
|
286
|
+
async def __aenter__(self):
    """Async context manager entry.

    Lazily establishes the (in-memory) connection on first entry and
    increments the nesting depth so nested ``async with`` blocks share
    a single connection.

    Returns:
        This store instance.
    """
    if not self._connected:
        await self.connection()
    # Depth counter lets __aexit__ disconnect only at the outermost exit.
    self._context_depth += 1
    return self
|
|
292
|
+
|
|
293
|
+
async def __aexit__(self, exc_type, exc_value, traceback):
    """Async context manager exit.

    Frees embedding resources, then disconnects only when the
    outermost context is exited (tracked via ``_context_depth``).
    Returns None, so exceptions from the managed block propagate.
    """
    # Free embedding resources
    # NOTE(review): this runs on every exit, even for nested contexts —
    # confirm _free_resources() is safe to call repeatedly.
    if self._embed_:
        await self._free_resources()

    try:
        # Only disconnect if we're exiting the outermost context
        self._context_depth -= 1
        if self._context_depth <= 0:
            await self.disconnect()
            # Clamp so unbalanced exits can't drive the depth negative.
            self._context_depth = 0
    except RuntimeError:
        # Deliberately swallowed (e.g. event loop already closing) so
        # cleanup failures don't mask an exception from the managed block.
        pass
|
|
307
|
+
|
|
308
|
+
async def prepare_embedding_table(
    self,
    collection: str = None,
    dimension: int = None,
    **kwargs
) -> bool:
    """Ensure the embedding collection exists with the right dimension.

    Creates the collection when missing; otherwise, when a different
    dimension is requested, records it and rebuilds the FAISS index at
    the new size (which discards the old index contents).

    Args:
        collection: Collection name (store default when None).
        dimension: Embedding dimension (store default when None).
        **kwargs: Forwarded to collection creation.

    Returns:
        True if successful.
    """
    collection = collection or self.collection_name
    dimension = dimension or self.dimension

    if collection not in self._collections:
        self.define_collection_table(collection, dimension, **kwargs)
        return True

    entry = self._collections[collection]
    if dimension and entry['dimension'] != dimension:
        # A dimension change invalidates the existing index, so build a
        # fresh one at the new size.
        entry['dimension'] = dimension
        entry['index'] = self._create_faiss_index(dimension)
        self.logger.info(f"Updated collection '{collection}' dimension to {dimension}")
    return True
|
|
339
|
+
|
|
340
|
+
async def create_embedding_table(
    self,
    collection: str = None,
    dimension: int = None,
    **kwargs
) -> None:
    """Create an embedding table/collection.

    Thin alias over :meth:`prepare_embedding_table`, kept for interface
    compatibility with other store backends.

    Args:
        collection: Collection name (store default when None).
        dimension: Embedding dimension (store default when None).
        **kwargs: Forwarded unchanged.
    """
    # Delegate; the result is intentionally discarded to keep this
    # method's declared ``None`` return.
    await self.prepare_embedding_table(collection, dimension, **kwargs)
|
|
355
|
+
|
|
356
|
+
async def create_collection(self, collection: str, **kwargs) -> None:
    """Create a new collection.

    Args:
        collection: Collection name.
        **kwargs: Additional arguments; ``dimension`` (int) overrides the
            store's default embedding dimension.
    """
    # Pop rather than get: leaving 'dimension' inside kwargs would pass
    # it both positionally and as a keyword to create_embedding_table,
    # raising ``TypeError: got multiple values for argument 'dimension'``.
    dimension = kwargs.pop('dimension', self.dimension)
    await self.create_embedding_table(collection, dimension, **kwargs)
|
|
366
|
+
|
|
367
|
+
async def add_documents(
    self,
    documents: List[Document],
    collection: str = None,
    embedding_column: str = None,
    content_column: str = None,
    metadata_column: str = None,
    **kwargs
) -> None:
    """Embed documents and add them to a collection's FAISS index.

    Each document gets a fresh UUID; its text, metadata and raw
    embedding are stored in the collection dicts, and its FAISS slot is
    recorded in the id<->index maps.

    Args:
        documents: List of Document objects to add.
        collection: Collection name (default collection when None).
        embedding_column: Unused; kept for interface compatibility.
        content_column: Unused; kept for interface compatibility.
        metadata_column: Unused; kept for interface compatibility.
        **kwargs: Additional arguments (currently unused here).
    """
    if not self._connected:
        await self.connection()

    collection = collection or self.collection_name

    # Ensure collection exists
    if collection not in self._collections:
        self._initialize_collection(collection)

    collection_data = self._collections[collection]

    # Extract texts and metadata
    texts = [doc.page_content for doc in documents]
    metadatas = [doc.metadata for doc in documents]

    # Generate embeddings
    # NOTE(review): embed_documents is called synchronously inside an
    # async method — confirm it is not blocking the event loop unduly.
    embeddings = self._embed_.embed_documents(texts)

    # Convert to numpy array (both branches coerce the same way; the
    # split only skips work when it is already an ndarray)
    if isinstance(embeddings, list):
        embeddings = np.array(embeddings, dtype=np.float32)
    elif not isinstance(embeddings, np.ndarray):
        embeddings = np.array(embeddings, dtype=np.float32)

    # Ensure 2D array (a single vector becomes a 1-row matrix)
    if embeddings.ndim == 1:
        embeddings = embeddings.reshape(1, -1)

    # Set dimension lazily from the first batch and build the index
    if collection_data['dimension'] is None:
        collection_data['dimension'] = embeddings.shape[1]
        collection_data['index'] = self._create_faiss_index(embeddings.shape[1])

    # Normalize embeddings for cosine similarity
    # (normalize_L2 mutates the array in place; the stored per-document
    # embeddings below are therefore the normalized vectors)
    if self.distance_strategy == DistanceStrategy.COSINE:
        faiss.normalize_L2(embeddings)

    # Train index if needed (for IVF); training requires at least nlist
    # vectors, otherwise adding proceeds on the untrained index
    if self.index_type == "IVF" and not collection_data['is_trained']:
        if len(embeddings) >= self.nlist:
            collection_data['index'].train(embeddings)
            collection_data['is_trained'] = True
            self.logger.info(f"Trained IVF index for collection '{collection}'")
        else:
            self.logger.warning(
                f"Not enough vectors to train IVF index "
                f"(need {self.nlist}, got {len(embeddings)}). Using Flat index temporarily."
            )

    # Get current index size — new vectors occupy slots starting here
    current_idx = collection_data['index'].ntotal

    # Add to FAISS index
    collection_data['index'].add(embeddings)

    # Store documents, metadata, and embeddings keyed by a fresh UUID,
    # and keep the bidirectional id<->FAISS-slot maps in sync
    for i, (text, metadata, embedding) in enumerate(zip(texts, metadatas, embeddings)):
        doc_id = str(uuid.uuid4())
        idx = current_idx + i

        collection_data['documents'][doc_id] = text
        collection_data['metadata'][doc_id] = metadata or {}
        collection_data['embeddings'][doc_id] = embedding
        collection_data['id_to_idx'][doc_id] = idx
        collection_data['idx_to_id'][idx] = doc_id

    self.logger.info(
        f"✅ Successfully added {len(documents)} documents to collection '{collection}'"
    )
|
|
456
|
+
|
|
457
|
+
def get_distance_strategy(
    self,
    query_embedding: np.ndarray,
    metric: str = None
) -> str:
    """Resolve the distance strategy for a query.

    Args:
        query_embedding: Query embedding vector (unused; kept for
            interface compatibility).
        metric: Optional metric name ('COSINE', 'L2', 'EUCLIDEAN',
            'IP', 'DOT', 'DOT_PRODUCT', 'MAX_INNER_PRODUCT').

    Returns:
        The resolved strategy. Note: despite the declared ``str`` return
        type, a DistanceStrategy member is returned, matching the
        original behavior.
    """
    strategy = metric or self.distance_strategy
    if not isinstance(strategy, str):
        # Already a DistanceStrategy member; pass it through unchanged.
        return strategy

    # Unknown names fall back to cosine similarity.
    return {
        'COSINE': DistanceStrategy.COSINE,
        'L2': DistanceStrategy.EUCLIDEAN_DISTANCE,
        'EUCLIDEAN': DistanceStrategy.EUCLIDEAN_DISTANCE,
        'IP': DistanceStrategy.MAX_INNER_PRODUCT,
        'DOT': DistanceStrategy.DOT_PRODUCT,
        'DOT_PRODUCT': DistanceStrategy.DOT_PRODUCT,
        'MAX_INNER_PRODUCT': DistanceStrategy.MAX_INNER_PRODUCT,
    }.get(strategy.upper(), DistanceStrategy.COSINE)
|
|
488
|
+
|
|
489
|
+
async def similarity_search(
    self,
    query: str,
    collection: str = None,
    k: Optional[int] = None,
    limit: int = None,
    metadata_filters: Optional[Dict[str, Any]] = None,
    score_threshold: Optional[float] = None,
    metric: str = None,
    embedding_column: str = None,
    content_column: str = None,
    metadata_column: str = None,
    id_column: str = None,
    **kwargs
) -> List[SearchResult]:
    """Perform similarity search with optional metadata/threshold filtering.

    Embeds the query, searches the collection's FAISS index, converts
    raw FAISS distances to similarity scores, then filters by metadata
    and score threshold while truncating to ``limit`` results.

    Args:
        query: The search query text.
        collection: Collection name (default collection when None).
        k: Number of results to return (alias for limit).
        limit: Maximum number of results to return (defaults to 10).
        metadata_filters: Exact-match key/value filters on metadata.
        score_threshold: Minimum score (for L2: maximum distance).
        metric: Accepted for interface compatibility; the configured
            ``self.distance_strategy`` is what actually drives scoring.
        embedding_column: Unused; kept for interface compatibility.
        content_column: Unused; kept for interface compatibility.
        metadata_column: Unused; kept for interface compatibility.
        id_column: Unused; kept for interface compatibility.
        **kwargs: Additional arguments (currently unused here).

    Returns:
        List of SearchResult objects with content, metadata, score, id.
    """
    if not self._connected:
        await self.connection()

    collection = collection or self.collection_name

    # Normalize k/limit: k wins only when limit is unset; default 10
    if k and not limit:
        limit = k
    if not limit:
        limit = 10

    # Missing or empty collection → empty result, not an error
    if collection not in self._collections:
        self.logger.warning(f"Collection '{collection}' not found")
        return []

    collection_data = self._collections[collection]

    if collection_data['index'] is None or collection_data['index'].ntotal == 0:
        self.logger.warning(f"Collection '{collection}' is empty")
        return []

    # Generate query embedding
    query_embedding = self._embed_.embed_query(query)

    # Convert to numpy array
    if isinstance(query_embedding, list):
        query_embedding = np.array(query_embedding, dtype=np.float32)
    elif not isinstance(query_embedding, np.ndarray):
        query_embedding = np.array(query_embedding, dtype=np.float32)

    # Ensure 2D array (FAISS expects a batch of query rows)
    if query_embedding.ndim == 1:
        query_embedding = query_embedding.reshape(1, -1)

    # Normalize for cosine similarity (must match how documents were
    # normalized at add time)
    if self.distance_strategy == DistanceStrategy.COSINE:
        faiss.normalize_L2(query_embedding)

    # Over-fetch 3x when metadata filtering will discard candidates,
    # capped at the index size
    search_k = limit * 3 if metadata_filters else limit
    search_k = min(search_k, collection_data['index'].ntotal)

    distances, indices = collection_data['index'].search(query_embedding, search_k)

    # Single query row → take the first (only) row of each result matrix
    distances = distances[0]
    indices = indices[0]

    # Convert raw FAISS values to similarity-like scores (higher=better)
    if self.distance_strategy == DistanceStrategy.COSINE:
        # With normalized vectors the inner product IS cosine similarity
        scores = distances
    elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
        # Map distance in [0, inf) to similarity in (0, 1]
        scores = 1.0 / (1.0 + distances)
    else:
        # Inner product is already similarity-like
        scores = distances

    # Build results
    results = []
    for idx, score in zip(indices, scores):
        if idx == -1:  # FAISS returns -1 for empty slots
            continue

        # Map the FAISS slot back to the document id
        # NOTE(review): idx is a numpy integer while idx_to_id keys are
        # Python ints; the lookup relies on equal hashing — verify.
        doc_id = collection_data['idx_to_id'].get(idx)
        if doc_id is None:
            continue

        # Get document data
        content = collection_data['documents'].get(doc_id, "")
        metadata = collection_data['metadata'].get(doc_id, {})

        # Apply metadata filters (exact equality on every key)
        if metadata_filters:
            match = True
            for key, value in metadata_filters.items():
                if metadata.get(key) != value:
                    match = False
                    break
            if not match:
                continue

        # Apply score threshold
        if score_threshold is not None:
            if self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
                # For L2 the threshold is applied to the RAW distance
                # (lower is better), so the comparison is inverted
                if distances[np.where(indices == idx)[0][0]] > score_threshold:
                    continue
            else:
                # For cosine and inner product, higher is better
                if score < score_threshold:
                    continue

        # Create search result
        result = SearchResult(
            id=doc_id,
            content=content,
            metadata=metadata,
            score=float(score)
        )
        results.append(result)

        # Stop once we have enough post-filter results
        if len(results) >= limit:
            break

    self.logger.debug(
        f"Similarity search in collection '{collection}': "
        f"found {len(results)} results (limit={limit})"
    )

    return results
|
|
640
|
+
|
|
641
|
+
async def asearch(
    self,
    query: str,
    collection: Optional[str] = None,
    k: Optional[int] = None,
    limit: Optional[int] = None,
    metadata_filters: Optional[Dict[str, Any]] = None,
    score_threshold: Optional[float] = None,
    metric: Optional[str] = None,
    embedding_column: Optional[str] = None,
    content_column: Optional[str] = None,
    metadata_column: Optional[str] = None,
    id_column: Optional[str] = None,
    **kwargs,
) -> List[SearchResult]:
    """Async search entry point matching the common store interface.

    Forwards every argument unchanged to :meth:`similarity_search` and
    returns its results.
    """
    return await self.similarity_search(
        query=query,
        collection=collection,
        k=k,
        limit=limit,
        metadata_filters=metadata_filters,
        score_threshold=score_threshold,
        metric=metric,
        embedding_column=embedding_column,
        content_column=content_column,
        metadata_column=metadata_column,
        id_column=id_column,
        **kwargs,
    )
|
|
672
|
+
|
|
673
|
+
async def mmr_search(
    self,
    query: str,
    collection: str = None,
    k: int = 4,
    fetch_k: Optional[int] = None,
    lambda_mult: float = 0.5,
    metadata_filters: Optional[Dict[str, Any]] = None,
    score_threshold: Optional[float] = None,
    metric: str = None,
    embedding_column: str = None,
    content_column: str = None,
    metadata_column: str = None,
    id_column: str = None,
    **kwargs
) -> List[SearchResult]:
    """Perform MMR (Maximal Marginal Relevance) search for diversity.

    Fetches ``fetch_k`` similarity-ranked candidates, then re-ranks
    them with :meth:`_mmr_algorithm` to balance query relevance against
    diversity from already-selected documents.

    Args:
        query: The search query text.
        collection: Collection name (default collection when None).
        k: Number of results to return.
        fetch_k: Number of candidates to fetch (default: max(k*3, 20)).
        lambda_mult: Diversity parameter in [0, 1]:
            1.0 = maximum relevance, 0.0 = maximum diversity.
        metadata_filters: Metadata filters applied at candidate fetch.
        score_threshold: Minimum score threshold for candidates.
        metric: Distance metric to use ('COSINE', 'L2', 'IP').
        embedding_column: Unused; kept for interface compatibility.
        content_column: Unused; kept for interface compatibility.
        metadata_column: Unused; kept for interface compatibility.
        id_column: Unused; kept for interface compatibility.
        **kwargs: Forwarded to :meth:`similarity_search`.

    Returns:
        List of SearchResult objects ordered by MMR selection order.
    """
    if not self._connected:
        await self.connection()

    collection = collection or self.collection_name

    # Default to fetching 3x more candidates than final results
    if fetch_k is None:
        fetch_k = max(k * 3, 20)

    # Step 1: Get initial candidates using similarity search
    candidates = await self.similarity_search(
        query=query,
        collection=collection,
        limit=fetch_k,
        metadata_filters=metadata_filters,
        score_threshold=score_threshold,
        metric=metric,
        embedding_column=embedding_column,
        content_column=content_column,
        metadata_column=metadata_column,
        id_column=id_column,
        **kwargs
    )

    # Nothing to diversify when candidates already fit in k
    if len(candidates) <= k:
        return candidates

    # Step 2: Collect stored embeddings for the candidates (documents
    # without a stored embedding are handled by a fallback inside the
    # MMR algorithm)
    collection_data = self._collections[collection]
    candidate_embeddings = {}
    for result in candidates:
        embedding = collection_data['embeddings'].get(result.id)
        if embedding is not None:
            candidate_embeddings[result.id] = embedding

    # Step 3: Get query embedding
    query_embedding = self._embed_.embed_query(query)
    if isinstance(query_embedding, list):
        query_embedding = np.array(query_embedding, dtype=np.float32)
    elif not isinstance(query_embedding, np.ndarray):
        query_embedding = np.array(query_embedding, dtype=np.float32)

    # Step 4: Run MMR algorithm
    selected_results = self._mmr_algorithm(
        query_embedding=query_embedding,
        candidates=candidates,
        candidate_embeddings=candidate_embeddings,
        k=k,
        lambda_mult=lambda_mult,
        metric=metric or self.distance_strategy
    )

    self.logger.info(
        f"MMR search in collection '{collection}': "
        f"selected {len(selected_results)} results from {len(candidates)} candidates "
        f"(λ={lambda_mult})"
    )

    return selected_results
|
|
776
|
+
|
|
777
|
+
def _mmr_algorithm(
    self,
    query_embedding: np.ndarray,
    candidates: List[SearchResult],
    candidate_embeddings: Dict[str, np.ndarray],
    k: int,
    lambda_mult: float,
    metric: Union[str, DistanceStrategy]
) -> List[SearchResult]:
    """Greedy MMR selection over a candidate list.

    Picks the most query-similar candidate first, then repeatedly
    selects the candidate maximizing
    ``lambda_mult * relevance - (1 - lambda_mult) * max_sim_to_selected``.

    Args:
        query_embedding: Query embedding vector.
        candidates: Candidate SearchResult objects (similarity-ordered).
        candidate_embeddings: Mapping of doc id -> embedding vector;
            candidates missing here are scored via a fallback.
        k: Number of results to select.
        lambda_mult: MMR diversity parameter in [0, 1].
        metric: Distance metric for pairwise similarity.

    Returns:
        Selected SearchResult objects in MMR order, each with
        'mmr_rank', 'mmr_lambda' and 'original_rank' added to metadata.
    """
    if len(candidates) <= k:
        return candidates

    # Convert query embedding to numpy array
    if not isinstance(query_embedding, np.ndarray):
        query_embedding = np.array(query_embedding, dtype=np.float32)

    # Prepare data structures
    selected_indices = []
    remaining_indices = list(range(len(candidates)))

    # Step 1: Compute relevance (similarity to query) per candidate
    query_similarities = []
    for candidate in candidates:
        doc_embedding = candidate_embeddings.get(candidate.id)
        if doc_embedding is not None:
            similarity = self._compute_similarity(query_embedding, doc_embedding, metric)
            query_similarities.append(similarity)
        else:
            # Fallback: derive a pseudo-similarity from the stored score
            # NOTE(review): treats score as a distance (lower=better),
            # which only holds for some metrics — verify intent.
            query_similarities.append(1.0 / (1.0 + candidate.score))

    # Select the most query-similar document first
    # (np.argmax yields a numpy integer; list.remove works via equality)
    best_idx = np.argmax(query_similarities)
    selected_indices.append(best_idx)
    remaining_indices.remove(best_idx)

    # Step 2: Iteratively select remaining documents using MMR
    for _ in range(min(k - 1, len(remaining_indices))):
        mmr_scores = []

        for idx in remaining_indices:
            candidate = candidates[idx]
            doc_embedding = candidate_embeddings.get(candidate.id)

            if doc_embedding is None:
                # No embedding → cannot measure diversity; score by
                # damped relevance only
                mmr_score = lambda_mult * query_similarities[idx]
                mmr_scores.append(mmr_score)
                continue

            # Relevance: similarity to query
            relevance = query_similarities[idx]

            # Diversity penalty: maximum similarity to any already
            # selected document
            max_similarity_to_selected = 0.0
            for selected_idx in selected_indices:
                selected_candidate = candidates[selected_idx]
                selected_embedding = candidate_embeddings.get(selected_candidate.id)

                if selected_embedding is not None:
                    similarity = self._compute_similarity(doc_embedding, selected_embedding, metric)
                    max_similarity_to_selected = max(max_similarity_to_selected, similarity)

            # MMR formula: λ * relevance - (1-λ) * max_similarity_to_selected
            mmr_score = (
                lambda_mult * relevance -
                (1.0 - lambda_mult) * max_similarity_to_selected
            )
            mmr_scores.append(mmr_score)

        # Select document with highest MMR score
        if mmr_scores:
            best_remaining_idx = np.argmax(mmr_scores)
            best_idx = remaining_indices[best_remaining_idx]
            selected_indices.append(best_idx)
            remaining_indices.remove(best_idx)

    # Step 3: Return selected results with MMR bookkeeping in metadata
    selected_results = []
    for i, idx in enumerate(selected_indices):
        result = candidates[idx]
        # Copy metadata so the original result object is not mutated
        enhanced_metadata = dict(result.metadata)
        enhanced_metadata['mmr_rank'] = i + 1
        enhanced_metadata['mmr_lambda'] = lambda_mult
        # original_rank is the candidate's 1-based position in the
        # similarity-ordered candidate list
        enhanced_metadata['original_rank'] = idx + 1

        enhanced_result = SearchResult(
            id=result.id,
            content=result.content,
            metadata=enhanced_metadata,
            score=result.score
        )
        selected_results.append(enhanced_result)

    return selected_results
|
|
887
|
+
|
|
888
|
+
def _compute_similarity(
    self,
    embedding1: np.ndarray,
    embedding2: np.ndarray,
    metric: Union[str, DistanceStrategy]
) -> float:
    """Compute similarity between two embeddings for the given metric.

    Uses direct NumPy scalar math instead of sklearn's pairwise
    functions: this avoids two reshape-to-2D round-trips and a pairwise
    matrix call per scalar comparison, and gives zero-norm vectors an
    explicit, warning-free similarity of 0 (matching sklearn's
    normalize-based behavior).

    Args:
        embedding1: First embedding vector (numpy array or list).
        embedding2: Second embedding vector (numpy array or list).
        metric: Distance metric name ('COSINE', 'L2', 'IP', ...) or a
            DistanceStrategy member.

    Returns:
        Similarity score as a Python float (higher = more similar).
    """
    # Coerce lists/arrays to flat float64 vectors in one step.
    v1 = np.asarray(embedding1, dtype=np.float64).ravel()
    v2 = np.asarray(embedding2, dtype=np.float64).ravel()

    # Convert string metrics to DistanceStrategy enum if needed
    if isinstance(metric, str):
        metric_mapping = {
            'COSINE': DistanceStrategy.COSINE,
            'L2': DistanceStrategy.EUCLIDEAN_DISTANCE,
            'EUCLIDEAN': DistanceStrategy.EUCLIDEAN_DISTANCE,
            'IP': DistanceStrategy.MAX_INNER_PRODUCT,
            'DOT': DistanceStrategy.DOT_PRODUCT,
            'DOT_PRODUCT': DistanceStrategy.DOT_PRODUCT,
            'MAX_INNER_PRODUCT': DistanceStrategy.MAX_INNER_PRODUCT
        }
        strategy = metric_mapping.get(metric.upper(), DistanceStrategy.COSINE)
    else:
        strategy = metric

    if strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
        # Map distance in [0, inf) to similarity in (0, 1]
        distance = float(np.linalg.norm(v1 - v2))
        similarity = 1.0 / (1.0 + distance)
    elif strategy in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.DOT_PRODUCT):
        # Dot product (inner product)
        similarity = float(np.dot(v1, v2))
    else:
        # COSINE and any unrecognized strategy default to cosine
        # similarity; zero-norm vectors have no direction, define 0.
        norm_product = float(np.linalg.norm(v1) * np.linalg.norm(v2))
        similarity = float(np.dot(v1, v2) / norm_product) if norm_product else 0.0

    return float(similarity)
|
|
953
|
+
|
|
954
|
+
# Additional methods for compatibility
|
|
955
|
+
|
|
956
|
+
def get_vector(self, metric_type: str = None, **kwargs):
    """Return the vector store handle (compatibility shim).

    FAISSStore acts as its own vector interface, so this simply
    returns ``self``; ``metric_type`` and any extra keyword arguments
    are accepted and ignored for signature compatibility.

    Returns:
        The FAISSStore instance itself.
    """
    return self
|
|
968
|
+
|
|
969
|
+
async def from_documents(
    self,
    documents: List[Document],
    collection: Union[str, None] = None,
    **kwargs
) -> 'FAISSStore':
    """Populate the store from a list of documents.

    Args:
        documents: Documents to embed and add.
        collection: Target collection name (default collection if None).
        **kwargs: Forwarded to :meth:`add_documents`.

    Returns:
        This FAISSStore instance (fluent style).
    """
    await self.add_documents(documents, collection=collection, **kwargs)
    return self
|
|
988
|
+
|
|
989
|
+
# Persistence methods
|
|
990
|
+
|
|
991
|
+
def save(self, file_path: Union[str, Path]) -> None:
    """Persist the store to disk.

    Each collection's FAISS index is written as a separate
    ``<stem>_<collection>.index`` sidecar file next to ``file_path``
    (FAISS has its own binary format and is not pickled); documents,
    metadata and raw embeddings are pickled into ``file_path`` itself.

    Args:
        file_path: Destination path for the pickled store.
    """
    target = Path(file_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    payload = {
        'collections': {},
        'config': {
            'collection_name': self.collection_name,
            'distance_strategy': self.distance_strategy.value,
            'index_type': self.index_type,
            'dimension': self.dimension,
        },
    }

    for name, data in self._collections.items():
        # Write the native FAISS index to its sidecar file, if present.
        index_file = target.parent / f"{target.stem}_{name}.index"
        if data['index'] is not None:
            faiss.write_index(data['index'], str(index_file))

        payload['collections'][name] = {
            'documents': data['documents'],
            'metadata': data['metadata'],
            # Arrays become plain lists so the pickle stays portable.
            'embeddings': {doc_id: vec.tolist() for doc_id, vec in data['embeddings'].items()},
            'id_to_idx': data['id_to_idx'],
            'idx_to_id': data['idx_to_id'],
            'dimension': data['dimension'],
            'is_trained': data['is_trained'],
            'index_path': str(index_file),
        }

    with open(target, 'wb') as f:
        pickle.dump(payload, f)

    self.logger.info(f"Saved FAISS store to {target}")
|
|
1036
|
+
|
|
1037
|
+
def load(self, file_path: Union[str, Path]) -> None:
    """
    Load the FAISS store from disk.

    Reads the pickle written by :meth:`save` and restores every
    collection, reloading each collection's FAISS index from its
    sidecar ``.index`` file when that file still exists. If the
    sidecar is missing the collection's index is set to ``None``
    (previously this raised ``UnboundLocalError``).

    NOTE(review): ``pickle.load`` executes arbitrary code — only load
    files produced by this store, never untrusted input.

    Args:
        file_path: Path to load the store from
    """
    file_path = Path(file_path)

    with open(file_path, 'rb') as f:
        save_data = pickle.load(f)

    # Restore config
    config = save_data.get('config', {})
    self.collection_name = config.get('collection_name', self.collection_name)

    # Restore collections
    for coll_name, coll_data in save_data['collections'].items():
        self._initialize_collection(coll_name)

        # Load the FAISS index from its sidecar file; fall back to
        # None when the sidecar no longer exists so the assignment
        # below never references an unbound name.
        index = None
        index_path = coll_data['index_path']
        if Path(index_path).exists():
            index = faiss.read_index(index_path)

        self._collections[coll_name]['index'] = index

        # Load metadata and documents; embeddings were pickled as
        # plain lists, so rebuild float32 arrays here.
        self._collections[coll_name]['documents'] = coll_data['documents']
        self._collections[coll_name]['metadata'] = coll_data['metadata']
        self._collections[coll_name]['embeddings'] = {
            k: np.array(v, dtype=np.float32)
            for k, v in coll_data['embeddings'].items()
        }
        self._collections[coll_name]['id_to_idx'] = coll_data['id_to_idx']
        self._collections[coll_name]['idx_to_id'] = coll_data['idx_to_id']
        self._collections[coll_name]['dimension'] = coll_data['dimension']
        self._collections[coll_name]['is_trained'] = coll_data['is_trained']

    self.logger.info(f"Loaded FAISS store from {file_path}")
|
|
1077
|
+
|
|
1078
|
+
def __str__(self) -> str:
    """Concise one-line summary: collection name and index type."""
    return "FAISSStore(collection={}, index_type={})".format(
        self.collection_name, self.index_type
    )
|
|
1080
|
+
|
|
1081
|
+
def __repr__(self) -> str:
    """Unambiguous representation; this store is always CPU-only."""
    return (
        "<FAISSStore(collection='{}', index_type='{}', "
        "distance_strategy='{}', cpu_only=True)>"
    ).format(self.collection_name, self.index_type, self.distance_strategy.value)
|
|
1088
|
+
|
|
1089
|
+
async def delete_documents(
    self,
    document_ids: List[str],
    collection: Optional[str] = None,
    **kwargs
) -> None:
    """
    Delete documents by their IDs from the FAISS store.

    FAISS does not support in-place vector deletion, so the raw
    vectors remain in the index; documents become unreachable by
    removing them from the bookkeeping mappings instead.

    Args:
        document_ids: List of document IDs to delete
        collection: Collection name (optional, uses default if not provided)
        **kwargs: Additional arguments
    """
    if not self._connected:
        await self.connection()

    collection = collection or self.collection_name

    # Ensure collection exists
    if collection not in self._collections:
        self.logger.warning(f"Collection '{collection}' not found")
        return

    collection_data = self._collections[collection]

    # Count the documents actually removed — previously the log
    # reported len(document_ids) even when some IDs were unknown.
    deleted = 0
    for doc_id in document_ids:
        idx = collection_data['id_to_idx'].get(doc_id)
        if idx is None:
            continue
        # Remove the document from every mapping so searches can no
        # longer resolve it (the vector stays in the FAISS index).
        del collection_data['documents'][doc_id]
        del collection_data['metadata'][doc_id]
        del collection_data['embeddings'][doc_id]
        del collection_data['id_to_idx'][doc_id]
        del collection_data['idx_to_id'][idx]
        deleted += 1

    self.logger.info(
        f"✅ Successfully deleted {deleted} documents from collection '{collection}'"
    )
|
|
1130
|
+
|
|
1131
|
+
async def delete_documents_by_filter(self, filter_func, collection: str = None, **kwargs) -> None:
    """
    Delete every document whose metadata satisfies *filter_func*.

    Args:
        filter_func: A function that takes metadata and returns True if the document should be deleted
        collection: Collection name (optional, uses default if not provided)
        **kwargs: Additional arguments forwarded to ``delete_documents``
    """
    if not self._connected:
        await self.connection()

    collection = collection or self.collection_name

    # Bail out early when the target collection does not exist.
    if collection not in self._collections:
        self.logger.warning(f"Collection '{collection}' not found")
        return

    # Collect matching IDs first, then delegate the actual removal.
    matching = [
        doc_id
        for doc_id, meta in self._collections[collection]['metadata'].items()
        if filter_func(meta)
    ]

    await self.delete_documents(matching, collection=collection, **kwargs)
|