ai-parrot 0.17.2 (ai_parrot-0.17.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentui/.prettierrc +15 -0
- agentui/QUICKSTART.md +272 -0
- agentui/README.md +59 -0
- agentui/env.example +16 -0
- agentui/jsconfig.json +14 -0
- agentui/package-lock.json +4242 -0
- agentui/package.json +34 -0
- agentui/scripts/postinstall/apply-patches.mjs +260 -0
- agentui/src/app.css +61 -0
- agentui/src/app.d.ts +13 -0
- agentui/src/app.html +12 -0
- agentui/src/components/LoadingSpinner.svelte +64 -0
- agentui/src/components/ThemeSwitcher.svelte +159 -0
- agentui/src/components/index.js +4 -0
- agentui/src/lib/api/bots.ts +60 -0
- agentui/src/lib/api/chat.ts +22 -0
- agentui/src/lib/api/http.ts +25 -0
- agentui/src/lib/components/BotCard.svelte +33 -0
- agentui/src/lib/components/ChatBubble.svelte +63 -0
- agentui/src/lib/components/Toast.svelte +21 -0
- agentui/src/lib/config.ts +20 -0
- agentui/src/lib/stores/auth.svelte.ts +73 -0
- agentui/src/lib/stores/theme.svelte.js +64 -0
- agentui/src/lib/stores/toast.svelte.ts +31 -0
- agentui/src/lib/utils/conversation.ts +39 -0
- agentui/src/routes/+layout.svelte +20 -0
- agentui/src/routes/+page.svelte +232 -0
- agentui/src/routes/login/+page.svelte +200 -0
- agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
- agentui/src/routes/talk/[agentId]/+page.ts +7 -0
- agentui/static/README.md +1 -0
- agentui/svelte.config.js +11 -0
- agentui/tailwind.config.ts +53 -0
- agentui/tsconfig.json +3 -0
- agentui/vite.config.ts +10 -0
- ai_parrot-0.17.2.dist-info/METADATA +472 -0
- ai_parrot-0.17.2.dist-info/RECORD +535 -0
- ai_parrot-0.17.2.dist-info/WHEEL +6 -0
- ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
- ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
- ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
- crew-builder/.prettierrc +15 -0
- crew-builder/QUICKSTART.md +259 -0
- crew-builder/README.md +113 -0
- crew-builder/env.example +17 -0
- crew-builder/jsconfig.json +14 -0
- crew-builder/package-lock.json +4182 -0
- crew-builder/package.json +37 -0
- crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
- crew-builder/src/app.css +62 -0
- crew-builder/src/app.d.ts +13 -0
- crew-builder/src/app.html +12 -0
- crew-builder/src/components/LoadingSpinner.svelte +64 -0
- crew-builder/src/components/ThemeSwitcher.svelte +149 -0
- crew-builder/src/components/index.js +9 -0
- crew-builder/src/lib/api/bots.ts +60 -0
- crew-builder/src/lib/api/chat.ts +80 -0
- crew-builder/src/lib/api/client.ts +56 -0
- crew-builder/src/lib/api/crew/crew.ts +136 -0
- crew-builder/src/lib/api/index.ts +5 -0
- crew-builder/src/lib/api/o365/auth.ts +65 -0
- crew-builder/src/lib/auth/auth.ts +54 -0
- crew-builder/src/lib/components/AgentNode.svelte +43 -0
- crew-builder/src/lib/components/BotCard.svelte +33 -0
- crew-builder/src/lib/components/ChatBubble.svelte +67 -0
- crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
- crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
- crew-builder/src/lib/components/JsonViewer.svelte +24 -0
- crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
- crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
- crew-builder/src/lib/components/Toast.svelte +67 -0
- crew-builder/src/lib/components/Toolbar.svelte +157 -0
- crew-builder/src/lib/components/index.ts +10 -0
- crew-builder/src/lib/config.ts +8 -0
- crew-builder/src/lib/stores/auth.svelte.ts +228 -0
- crew-builder/src/lib/stores/crewStore.ts +369 -0
- crew-builder/src/lib/stores/theme.svelte.js +145 -0
- crew-builder/src/lib/stores/toast.svelte.ts +69 -0
- crew-builder/src/lib/utils/conversation.ts +39 -0
- crew-builder/src/lib/utils/markdown.ts +122 -0
- crew-builder/src/lib/utils/talkHistory.ts +47 -0
- crew-builder/src/routes/+layout.svelte +20 -0
- crew-builder/src/routes/+page.svelte +539 -0
- crew-builder/src/routes/agents/+page.svelte +247 -0
- crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
- crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
- crew-builder/src/routes/builder/+page.svelte +204 -0
- crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
- crew-builder/src/routes/crew/ask/+page.ts +1 -0
- crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
- crew-builder/src/routes/login/+page.svelte +197 -0
- crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
- crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
- crew-builder/static/README.md +1 -0
- crew-builder/svelte.config.js +11 -0
- crew-builder/tailwind.config.ts +53 -0
- crew-builder/tsconfig.json +3 -0
- crew-builder/vite.config.ts +10 -0
- mcp_servers/calculator_server.py +309 -0
- parrot/__init__.py +27 -0
- parrot/__pycache__/__init__.cpython-310.pyc +0 -0
- parrot/__pycache__/version.cpython-310.pyc +0 -0
- parrot/_version.py +34 -0
- parrot/a2a/__init__.py +48 -0
- parrot/a2a/client.py +658 -0
- parrot/a2a/discovery.py +89 -0
- parrot/a2a/mixin.py +257 -0
- parrot/a2a/models.py +376 -0
- parrot/a2a/server.py +770 -0
- parrot/agents/__init__.py +29 -0
- parrot/bots/__init__.py +12 -0
- parrot/bots/a2a_agent.py +19 -0
- parrot/bots/abstract.py +3139 -0
- parrot/bots/agent.py +1129 -0
- parrot/bots/basic.py +9 -0
- parrot/bots/chatbot.py +669 -0
- parrot/bots/data.py +1618 -0
- parrot/bots/database/__init__.py +5 -0
- parrot/bots/database/abstract.py +3071 -0
- parrot/bots/database/cache.py +286 -0
- parrot/bots/database/models.py +468 -0
- parrot/bots/database/prompts.py +154 -0
- parrot/bots/database/retries.py +98 -0
- parrot/bots/database/router.py +269 -0
- parrot/bots/database/sql.py +41 -0
- parrot/bots/db/__init__.py +6 -0
- parrot/bots/db/abstract.py +556 -0
- parrot/bots/db/bigquery.py +602 -0
- parrot/bots/db/cache.py +85 -0
- parrot/bots/db/documentdb.py +668 -0
- parrot/bots/db/elastic.py +1014 -0
- parrot/bots/db/influx.py +898 -0
- parrot/bots/db/mock.py +96 -0
- parrot/bots/db/multi.py +783 -0
- parrot/bots/db/prompts.py +185 -0
- parrot/bots/db/sql.py +1255 -0
- parrot/bots/db/tools.py +212 -0
- parrot/bots/document.py +680 -0
- parrot/bots/hrbot.py +15 -0
- parrot/bots/kb.py +170 -0
- parrot/bots/mcp.py +36 -0
- parrot/bots/orchestration/README.md +463 -0
- parrot/bots/orchestration/__init__.py +1 -0
- parrot/bots/orchestration/agent.py +155 -0
- parrot/bots/orchestration/crew.py +3330 -0
- parrot/bots/orchestration/fsm.py +1179 -0
- parrot/bots/orchestration/hr.py +434 -0
- parrot/bots/orchestration/storage/__init__.py +4 -0
- parrot/bots/orchestration/storage/memory.py +100 -0
- parrot/bots/orchestration/storage/mixin.py +119 -0
- parrot/bots/orchestration/verify.py +202 -0
- parrot/bots/product.py +204 -0
- parrot/bots/prompts/__init__.py +96 -0
- parrot/bots/prompts/agents.py +155 -0
- parrot/bots/prompts/data.py +216 -0
- parrot/bots/prompts/output_generation.py +8 -0
- parrot/bots/scraper/__init__.py +3 -0
- parrot/bots/scraper/models.py +122 -0
- parrot/bots/scraper/scraper.py +1173 -0
- parrot/bots/scraper/templates.py +115 -0
- parrot/bots/stores/__init__.py +5 -0
- parrot/bots/stores/local.py +172 -0
- parrot/bots/webdev.py +81 -0
- parrot/cli.py +17 -0
- parrot/clients/__init__.py +16 -0
- parrot/clients/base.py +1491 -0
- parrot/clients/claude.py +1191 -0
- parrot/clients/factory.py +129 -0
- parrot/clients/google.py +4567 -0
- parrot/clients/gpt.py +1975 -0
- parrot/clients/grok.py +432 -0
- parrot/clients/groq.py +986 -0
- parrot/clients/hf.py +582 -0
- parrot/clients/models.py +18 -0
- parrot/conf.py +395 -0
- parrot/embeddings/__init__.py +9 -0
- parrot/embeddings/base.py +157 -0
- parrot/embeddings/google.py +98 -0
- parrot/embeddings/huggingface.py +74 -0
- parrot/embeddings/openai.py +84 -0
- parrot/embeddings/processor.py +88 -0
- parrot/exceptions.c +13868 -0
- parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/exceptions.pxd +22 -0
- parrot/exceptions.pxi +15 -0
- parrot/exceptions.pyx +44 -0
- parrot/generators/__init__.py +29 -0
- parrot/generators/base.py +200 -0
- parrot/generators/html.py +293 -0
- parrot/generators/react.py +205 -0
- parrot/generators/streamlit.py +203 -0
- parrot/generators/template.py +105 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/agent.py +861 -0
- parrot/handlers/agents/__init__.py +1 -0
- parrot/handlers/agents/abstract.py +900 -0
- parrot/handlers/bots.py +338 -0
- parrot/handlers/chat.py +915 -0
- parrot/handlers/creation.sql +192 -0
- parrot/handlers/crew/ARCHITECTURE.md +362 -0
- parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
- parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
- parrot/handlers/crew/__init__.py +0 -0
- parrot/handlers/crew/handler.py +801 -0
- parrot/handlers/crew/models.py +229 -0
- parrot/handlers/crew/redis_persistence.py +523 -0
- parrot/handlers/jobs/__init__.py +10 -0
- parrot/handlers/jobs/job.py +384 -0
- parrot/handlers/jobs/mixin.py +627 -0
- parrot/handlers/jobs/models.py +115 -0
- parrot/handlers/jobs/worker.py +31 -0
- parrot/handlers/models.py +596 -0
- parrot/handlers/o365_auth.py +105 -0
- parrot/handlers/stream.py +337 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/aws.py +143 -0
- parrot/interfaces/credentials.py +113 -0
- parrot/interfaces/database.py +27 -0
- parrot/interfaces/google.py +1123 -0
- parrot/interfaces/hierarchy.py +1227 -0
- parrot/interfaces/http.py +651 -0
- parrot/interfaces/images/__init__.py +0 -0
- parrot/interfaces/images/plugins/__init__.py +24 -0
- parrot/interfaces/images/plugins/abstract.py +58 -0
- parrot/interfaces/images/plugins/analisys.py +148 -0
- parrot/interfaces/images/plugins/classify.py +150 -0
- parrot/interfaces/images/plugins/classifybase.py +182 -0
- parrot/interfaces/images/plugins/detect.py +150 -0
- parrot/interfaces/images/plugins/exif.py +1103 -0
- parrot/interfaces/images/plugins/hash.py +52 -0
- parrot/interfaces/images/plugins/vision.py +104 -0
- parrot/interfaces/images/plugins/yolo.py +66 -0
- parrot/interfaces/images/plugins/zerodetect.py +197 -0
- parrot/interfaces/o365.py +978 -0
- parrot/interfaces/onedrive.py +822 -0
- parrot/interfaces/sharepoint.py +1435 -0
- parrot/interfaces/soap.py +257 -0
- parrot/loaders/__init__.py +8 -0
- parrot/loaders/abstract.py +1131 -0
- parrot/loaders/audio.py +199 -0
- parrot/loaders/basepdf.py +53 -0
- parrot/loaders/basevideo.py +1568 -0
- parrot/loaders/csv.py +409 -0
- parrot/loaders/docx.py +116 -0
- parrot/loaders/epubloader.py +316 -0
- parrot/loaders/excel.py +199 -0
- parrot/loaders/factory.py +55 -0
- parrot/loaders/files/__init__.py +0 -0
- parrot/loaders/files/abstract.py +39 -0
- parrot/loaders/files/html.py +26 -0
- parrot/loaders/files/text.py +63 -0
- parrot/loaders/html.py +152 -0
- parrot/loaders/markdown.py +442 -0
- parrot/loaders/pdf.py +373 -0
- parrot/loaders/pdfmark.py +320 -0
- parrot/loaders/pdftables.py +506 -0
- parrot/loaders/ppt.py +476 -0
- parrot/loaders/qa.py +63 -0
- parrot/loaders/splitters/__init__.py +10 -0
- parrot/loaders/splitters/base.py +138 -0
- parrot/loaders/splitters/md.py +228 -0
- parrot/loaders/splitters/token.py +143 -0
- parrot/loaders/txt.py +26 -0
- parrot/loaders/video.py +89 -0
- parrot/loaders/videolocal.py +218 -0
- parrot/loaders/videounderstanding.py +377 -0
- parrot/loaders/vimeo.py +167 -0
- parrot/loaders/web.py +599 -0
- parrot/loaders/youtube.py +504 -0
- parrot/manager/__init__.py +5 -0
- parrot/manager/manager.py +1030 -0
- parrot/mcp/__init__.py +28 -0
- parrot/mcp/adapter.py +105 -0
- parrot/mcp/cli.py +174 -0
- parrot/mcp/client.py +119 -0
- parrot/mcp/config.py +75 -0
- parrot/mcp/integration.py +842 -0
- parrot/mcp/oauth.py +933 -0
- parrot/mcp/server.py +225 -0
- parrot/mcp/transports/__init__.py +3 -0
- parrot/mcp/transports/base.py +279 -0
- parrot/mcp/transports/grpc_session.py +163 -0
- parrot/mcp/transports/http.py +312 -0
- parrot/mcp/transports/mcp.proto +108 -0
- parrot/mcp/transports/quic.py +1082 -0
- parrot/mcp/transports/sse.py +330 -0
- parrot/mcp/transports/stdio.py +309 -0
- parrot/mcp/transports/unix.py +395 -0
- parrot/mcp/transports/websocket.py +547 -0
- parrot/memory/__init__.py +16 -0
- parrot/memory/abstract.py +209 -0
- parrot/memory/agent.py +32 -0
- parrot/memory/cache.py +175 -0
- parrot/memory/core.py +555 -0
- parrot/memory/file.py +153 -0
- parrot/memory/mem.py +131 -0
- parrot/memory/redis.py +613 -0
- parrot/models/__init__.py +46 -0
- parrot/models/basic.py +118 -0
- parrot/models/compliance.py +208 -0
- parrot/models/crew.py +395 -0
- parrot/models/detections.py +654 -0
- parrot/models/generation.py +85 -0
- parrot/models/google.py +223 -0
- parrot/models/groq.py +23 -0
- parrot/models/openai.py +30 -0
- parrot/models/outputs.py +285 -0
- parrot/models/responses.py +938 -0
- parrot/notifications/__init__.py +743 -0
- parrot/openapi/__init__.py +3 -0
- parrot/openapi/components.yaml +641 -0
- parrot/openapi/config.py +322 -0
- parrot/outputs/__init__.py +32 -0
- parrot/outputs/formats/__init__.py +108 -0
- parrot/outputs/formats/altair.py +359 -0
- parrot/outputs/formats/application.py +122 -0
- parrot/outputs/formats/base.py +351 -0
- parrot/outputs/formats/bokeh.py +356 -0
- parrot/outputs/formats/card.py +424 -0
- parrot/outputs/formats/chart.py +436 -0
- parrot/outputs/formats/d3.py +255 -0
- parrot/outputs/formats/echarts.py +310 -0
- parrot/outputs/formats/generators/__init__.py +0 -0
- parrot/outputs/formats/generators/abstract.py +61 -0
- parrot/outputs/formats/generators/panel.py +145 -0
- parrot/outputs/formats/generators/streamlit.py +86 -0
- parrot/outputs/formats/generators/terminal.py +63 -0
- parrot/outputs/formats/holoviews.py +310 -0
- parrot/outputs/formats/html.py +147 -0
- parrot/outputs/formats/jinja2.py +46 -0
- parrot/outputs/formats/json.py +87 -0
- parrot/outputs/formats/map.py +933 -0
- parrot/outputs/formats/markdown.py +172 -0
- parrot/outputs/formats/matplotlib.py +237 -0
- parrot/outputs/formats/mixins/__init__.py +0 -0
- parrot/outputs/formats/mixins/emaps.py +855 -0
- parrot/outputs/formats/plotly.py +341 -0
- parrot/outputs/formats/seaborn.py +310 -0
- parrot/outputs/formats/table.py +397 -0
- parrot/outputs/formats/template_report.py +138 -0
- parrot/outputs/formats/yaml.py +125 -0
- parrot/outputs/formatter.py +152 -0
- parrot/outputs/templates/__init__.py +95 -0
- parrot/pipelines/__init__.py +0 -0
- parrot/pipelines/abstract.py +210 -0
- parrot/pipelines/detector.py +124 -0
- parrot/pipelines/models.py +90 -0
- parrot/pipelines/planogram.py +3002 -0
- parrot/pipelines/table.sql +97 -0
- parrot/plugins/__init__.py +106 -0
- parrot/plugins/importer.py +80 -0
- parrot/py.typed +0 -0
- parrot/registry/__init__.py +18 -0
- parrot/registry/registry.py +594 -0
- parrot/scheduler/__init__.py +1189 -0
- parrot/scheduler/models.py +60 -0
- parrot/security/__init__.py +16 -0
- parrot/security/prompt_injection.py +268 -0
- parrot/security/security_events.sql +25 -0
- parrot/services/__init__.py +1 -0
- parrot/services/mcp/__init__.py +8 -0
- parrot/services/mcp/config.py +13 -0
- parrot/services/mcp/server.py +295 -0
- parrot/services/o365_remote_auth.py +235 -0
- parrot/stores/__init__.py +7 -0
- parrot/stores/abstract.py +352 -0
- parrot/stores/arango.py +1090 -0
- parrot/stores/bigquery.py +1377 -0
- parrot/stores/cache.py +106 -0
- parrot/stores/empty.py +10 -0
- parrot/stores/faiss_store.py +1157 -0
- parrot/stores/kb/__init__.py +9 -0
- parrot/stores/kb/abstract.py +68 -0
- parrot/stores/kb/cache.py +165 -0
- parrot/stores/kb/doc.py +325 -0
- parrot/stores/kb/hierarchy.py +346 -0
- parrot/stores/kb/local.py +457 -0
- parrot/stores/kb/prompt.py +28 -0
- parrot/stores/kb/redis.py +659 -0
- parrot/stores/kb/store.py +115 -0
- parrot/stores/kb/user.py +374 -0
- parrot/stores/models.py +59 -0
- parrot/stores/pgvector.py +3 -0
- parrot/stores/postgres.py +2853 -0
- parrot/stores/utils/__init__.py +0 -0
- parrot/stores/utils/chunking.py +197 -0
- parrot/telemetry/__init__.py +3 -0
- parrot/telemetry/mixin.py +111 -0
- parrot/template/__init__.py +3 -0
- parrot/template/engine.py +259 -0
- parrot/tools/__init__.py +23 -0
- parrot/tools/abstract.py +644 -0
- parrot/tools/agent.py +363 -0
- parrot/tools/arangodbsearch.py +537 -0
- parrot/tools/arxiv_tool.py +188 -0
- parrot/tools/calculator/__init__.py +3 -0
- parrot/tools/calculator/operations/__init__.py +38 -0
- parrot/tools/calculator/operations/calculus.py +80 -0
- parrot/tools/calculator/operations/statistics.py +76 -0
- parrot/tools/calculator/tool.py +150 -0
- parrot/tools/cloudwatch.py +988 -0
- parrot/tools/codeinterpreter/__init__.py +127 -0
- parrot/tools/codeinterpreter/executor.py +371 -0
- parrot/tools/codeinterpreter/internals.py +473 -0
- parrot/tools/codeinterpreter/models.py +643 -0
- parrot/tools/codeinterpreter/prompts.py +224 -0
- parrot/tools/codeinterpreter/tool.py +664 -0
- parrot/tools/company_info/__init__.py +6 -0
- parrot/tools/company_info/tool.py +1138 -0
- parrot/tools/correlationanalysis.py +437 -0
- parrot/tools/database/abstract.py +286 -0
- parrot/tools/database/bq.py +115 -0
- parrot/tools/database/cache.py +284 -0
- parrot/tools/database/models.py +95 -0
- parrot/tools/database/pg.py +343 -0
- parrot/tools/databasequery.py +1159 -0
- parrot/tools/db.py +1800 -0
- parrot/tools/ddgo.py +370 -0
- parrot/tools/decorators.py +271 -0
- parrot/tools/dftohtml.py +282 -0
- parrot/tools/document.py +549 -0
- parrot/tools/ecs.py +819 -0
- parrot/tools/edareport.py +368 -0
- parrot/tools/elasticsearch.py +1049 -0
- parrot/tools/employees.py +462 -0
- parrot/tools/epson/__init__.py +96 -0
- parrot/tools/excel.py +683 -0
- parrot/tools/file/__init__.py +13 -0
- parrot/tools/file/abstract.py +76 -0
- parrot/tools/file/gcs.py +378 -0
- parrot/tools/file/local.py +284 -0
- parrot/tools/file/s3.py +511 -0
- parrot/tools/file/tmp.py +309 -0
- parrot/tools/file/tool.py +501 -0
- parrot/tools/file_reader.py +129 -0
- parrot/tools/flowtask/__init__.py +19 -0
- parrot/tools/flowtask/tool.py +761 -0
- parrot/tools/gittoolkit.py +508 -0
- parrot/tools/google/__init__.py +18 -0
- parrot/tools/google/base.py +169 -0
- parrot/tools/google/tools.py +1251 -0
- parrot/tools/googlelocation.py +5 -0
- parrot/tools/googleroutes.py +5 -0
- parrot/tools/googlesearch.py +5 -0
- parrot/tools/googlesitesearch.py +5 -0
- parrot/tools/googlevoice.py +2 -0
- parrot/tools/gvoice.py +695 -0
- parrot/tools/ibisworld/README.md +225 -0
- parrot/tools/ibisworld/__init__.py +11 -0
- parrot/tools/ibisworld/tool.py +366 -0
- parrot/tools/jiratoolkit.py +1718 -0
- parrot/tools/manager.py +1098 -0
- parrot/tools/math.py +152 -0
- parrot/tools/metadata.py +476 -0
- parrot/tools/msteams.py +1621 -0
- parrot/tools/msword.py +635 -0
- parrot/tools/multidb.py +580 -0
- parrot/tools/multistoresearch.py +369 -0
- parrot/tools/networkninja.py +167 -0
- parrot/tools/nextstop/__init__.py +4 -0
- parrot/tools/nextstop/base.py +286 -0
- parrot/tools/nextstop/employee.py +733 -0
- parrot/tools/nextstop/store.py +462 -0
- parrot/tools/notification.py +435 -0
- parrot/tools/o365/__init__.py +42 -0
- parrot/tools/o365/base.py +295 -0
- parrot/tools/o365/bundle.py +522 -0
- parrot/tools/o365/events.py +554 -0
- parrot/tools/o365/mail.py +992 -0
- parrot/tools/o365/onedrive.py +497 -0
- parrot/tools/o365/sharepoint.py +641 -0
- parrot/tools/openapi_toolkit.py +904 -0
- parrot/tools/openweather.py +527 -0
- parrot/tools/pdfprint.py +1001 -0
- parrot/tools/powerbi.py +518 -0
- parrot/tools/powerpoint.py +1113 -0
- parrot/tools/pricestool.py +146 -0
- parrot/tools/products/__init__.py +246 -0
- parrot/tools/prophet_tool.py +171 -0
- parrot/tools/pythonpandas.py +630 -0
- parrot/tools/pythonrepl.py +910 -0
- parrot/tools/qsource.py +436 -0
- parrot/tools/querytoolkit.py +395 -0
- parrot/tools/quickeda.py +827 -0
- parrot/tools/resttool.py +553 -0
- parrot/tools/retail/__init__.py +0 -0
- parrot/tools/retail/bby.py +528 -0
- parrot/tools/sandboxtool.py +703 -0
- parrot/tools/sassie/__init__.py +352 -0
- parrot/tools/scraping/__init__.py +7 -0
- parrot/tools/scraping/docs/select.md +466 -0
- parrot/tools/scraping/documentation.md +1278 -0
- parrot/tools/scraping/driver.py +436 -0
- parrot/tools/scraping/models.py +576 -0
- parrot/tools/scraping/options.py +85 -0
- parrot/tools/scraping/orchestrator.py +517 -0
- parrot/tools/scraping/readme.md +740 -0
- parrot/tools/scraping/tool.py +3115 -0
- parrot/tools/seasonaldetection.py +642 -0
- parrot/tools/shell_tool/__init__.py +5 -0
- parrot/tools/shell_tool/actions.py +408 -0
- parrot/tools/shell_tool/engine.py +155 -0
- parrot/tools/shell_tool/models.py +322 -0
- parrot/tools/shell_tool/tool.py +442 -0
- parrot/tools/site_search.py +214 -0
- parrot/tools/textfile.py +418 -0
- parrot/tools/think.py +378 -0
- parrot/tools/toolkit.py +298 -0
- parrot/tools/webapp_tool.py +187 -0
- parrot/tools/whatif.py +1279 -0
- parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
- parrot/tools/workday/__init__.py +6 -0
- parrot/tools/workday/models.py +1389 -0
- parrot/tools/workday/tool.py +1293 -0
- parrot/tools/yfinance_tool.py +306 -0
- parrot/tools/zipcode.py +217 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/helpers.py +73 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.c +12078 -0
- parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/parsers/toml.pyx +21 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpp +20936 -0
- parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/types.pyx +213 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- parrot/yaml-rs/Cargo.lock +350 -0
- parrot/yaml-rs/Cargo.toml +19 -0
- parrot/yaml-rs/pyproject.toml +19 -0
- parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
- parrot/yaml-rs/src/lib.rs +222 -0
- requirements/docker-compose.yml +24 -0
- requirements/requirements-dev.txt +21 -0
|
@@ -0,0 +1,1377 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional, Union, Callable
|
|
2
|
+
import uuid
|
|
3
|
+
import time
|
|
4
|
+
import asyncio
|
|
5
|
+
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
|
|
6
|
+
from google.cloud import bigquery as bq
|
|
7
|
+
from google.oauth2 import service_account
|
|
8
|
+
from google.cloud.exceptions import NotFound, Conflict
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pandas as pd
|
|
11
|
+
from navconfig.logging import logging
|
|
12
|
+
from .abstract import AbstractStore
|
|
13
|
+
from ..conf import (
|
|
14
|
+
BIGQUERY_CREDENTIALS,
|
|
15
|
+
BIGQUERY_PROJECT_ID,
|
|
16
|
+
BIGQUERY_DATASET
|
|
17
|
+
)
|
|
18
|
+
from .models import SearchResult, Document, DistanceStrategy
|
|
19
|
+
from .utils.chunking import LateChunkingProcessor
|
|
20
|
+
from ..exceptions import DriverError
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class BigQueryStore(AbstractStore):
|
|
25
|
+
"""
|
|
26
|
+
A BigQuery vector store implementation for storing and searching embeddings.
|
|
27
|
+
This store provides vector similarity search capabilities using BigQuery's ML functions.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
table: str = None,
|
|
33
|
+
dataset: str = None,
|
|
34
|
+
project_id: str = None,
|
|
35
|
+
credentials: str = None,
|
|
36
|
+
id_column: str = 'id',
|
|
37
|
+
embedding_column: str = 'embedding',
|
|
38
|
+
document_column: str = 'document',
|
|
39
|
+
text_column: str = 'text',
|
|
40
|
+
metadata_column: str = 'metadata',
|
|
41
|
+
embedding_model: Union[dict, str] = "sentence-transformers/all-mpnet-base-v2",
|
|
42
|
+
embedding: Optional[Callable] = None,
|
|
43
|
+
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
|
|
44
|
+
auto_initialize: bool = True,
|
|
45
|
+
**kwargs
|
|
46
|
+
):
|
|
47
|
+
"""Initialize the BigQueryStore with the specified parameters."""
|
|
48
|
+
self.table_name = table
|
|
49
|
+
self.dataset = dataset or BIGQUERY_DATASET
|
|
50
|
+
self._project_id = project_id or BIGQUERY_PROJECT_ID
|
|
51
|
+
self._credentials = credentials or BIGQUERY_CREDENTIALS
|
|
52
|
+
|
|
53
|
+
# Column definitions
|
|
54
|
+
self._id_column: str = id_column
|
|
55
|
+
self._embedding_column: str = embedding_column
|
|
56
|
+
self._document_column: str = document_column
|
|
57
|
+
self._text_column: str = text_column
|
|
58
|
+
self._metadata_column: str = metadata_column
|
|
59
|
+
|
|
60
|
+
# Configuration
|
|
61
|
+
self.distance_strategy = distance_strategy
|
|
62
|
+
self._auto_initialize: bool = auto_initialize
|
|
63
|
+
self._collection_store_cache: Dict[str, Any] = {}
|
|
64
|
+
|
|
65
|
+
# Initialize parent class
|
|
66
|
+
super().__init__(
|
|
67
|
+
embedding_model=embedding_model,
|
|
68
|
+
embedding=embedding,
|
|
69
|
+
**kwargs
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
# BigQuery client and session management
|
|
73
|
+
self._connection: Optional[bq.Client] = None
|
|
74
|
+
self._connected: bool = False
|
|
75
|
+
self.credentials = None
|
|
76
|
+
self._account = None
|
|
77
|
+
|
|
78
|
+
# Initialize logger
|
|
79
|
+
self.logger = logging.getLogger("BigQueryStore")
|
|
80
|
+
|
|
81
|
+
def get_vector(self, metric_type: str = None, **kwargs):
|
|
82
|
+
raise NotImplementedError("This method is part of the old implementation.")
|
|
83
|
+
|
|
84
|
+
def _execute_query(
|
|
85
|
+
self,
|
|
86
|
+
query: str,
|
|
87
|
+
job_config: Optional[bq.QueryJobConfig] = None
|
|
88
|
+
) -> List[Dict[str, Any]]:
|
|
89
|
+
query_job = self._connection.query(query, job_config=job_config)
|
|
90
|
+
return list(query_job.result())
|
|
91
|
+
|
|
92
|
+
async def _thread_func(self, func, *args, **kwargs):
|
|
93
|
+
"""
|
|
94
|
+
Run a synchronous function in an async context.
|
|
95
|
+
Helper for running blocking calls in a non-blocking way.
|
|
96
|
+
"""
|
|
97
|
+
return await asyncio.to_thread(func, *args, **kwargs)
|
|
98
|
+
|
|
99
|
+
async def connection(self):
|
|
100
|
+
"""Initialize BigQuery client.
|
|
101
|
+
Assuming that authentication is handled outside
|
|
102
|
+
(via environment variables or similar)
|
|
103
|
+
"""
|
|
104
|
+
try:
|
|
105
|
+
if self._credentials: # usage of explicit credentials
|
|
106
|
+
self.credentials = service_account.Credentials.from_service_account_file(
|
|
107
|
+
self._credentials
|
|
108
|
+
)
|
|
109
|
+
if not self._project_id:
|
|
110
|
+
self._project_id = self.credentials.project_id
|
|
111
|
+
self._connection = bq.Client(credentials=self.credentials, project=self._project_id)
|
|
112
|
+
self._connected = True
|
|
113
|
+
else:
|
|
114
|
+
self.credentials = self._account
|
|
115
|
+
self._connection = bq.Client(project=self._project_id)
|
|
116
|
+
self._connected = True
|
|
117
|
+
|
|
118
|
+
if self._auto_initialize:
|
|
119
|
+
await self.initialize_database()
|
|
120
|
+
|
|
121
|
+
self.logger.debug("Successfully connected to BigQuery.")
|
|
122
|
+
|
|
123
|
+
except Exception as e:
|
|
124
|
+
self._connected = False
|
|
125
|
+
raise DriverError(f"BigQuery: Error initializing client: {e}")
|
|
126
|
+
return self
|
|
127
|
+
|
|
128
|
+
async def initialize_database(self):
|
|
129
|
+
"""Initialize BigQuery dataset and any required setup."""
|
|
130
|
+
try:
|
|
131
|
+
# Ensure dataset exists
|
|
132
|
+
dataset_id = f"{self._project_id}.{self.dataset}"
|
|
133
|
+
try:
|
|
134
|
+
self._connection.get_dataset(dataset_id)
|
|
135
|
+
self.logger.info(f"Dataset {dataset_id} already exists")
|
|
136
|
+
except NotFound:
|
|
137
|
+
# Create dataset if it doesn't exist
|
|
138
|
+
dataset = bq.Dataset(dataset_id)
|
|
139
|
+
dataset.location = "US"
|
|
140
|
+
dataset = self._connection.create_dataset(dataset, timeout=30)
|
|
141
|
+
self.logger.info(f"Created dataset {dataset_id}")
|
|
142
|
+
|
|
143
|
+
except Exception as e:
|
|
144
|
+
self.logger.warning(
|
|
145
|
+
f"⚠️ Database auto-initialization failed: {e}"
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
def _define_collection_store(
|
|
149
|
+
self,
|
|
150
|
+
table: str,
|
|
151
|
+
dataset: str,
|
|
152
|
+
dimension: int = 384,
|
|
153
|
+
id_column: str = 'id',
|
|
154
|
+
embedding_column: str = 'embedding',
|
|
155
|
+
document_column: str = 'document',
|
|
156
|
+
metadata_column: str = 'metadata',
|
|
157
|
+
text_column: str = 'text'
|
|
158
|
+
) -> str:
|
|
159
|
+
"""Define a collection store table name for BigQuery.
|
|
160
|
+
|
|
161
|
+
Returns:
|
|
162
|
+
str: Fully qualified table name in format project.dataset.table
|
|
163
|
+
"""
|
|
164
|
+
full_table_name = f"{self._project_id}.{dataset}.{table}"
|
|
165
|
+
|
|
166
|
+
if full_table_name in self._collection_store_cache:
|
|
167
|
+
return self._collection_store_cache[full_table_name]
|
|
168
|
+
|
|
169
|
+
# Cache the table reference
|
|
170
|
+
self._collection_store_cache[full_table_name] = {
|
|
171
|
+
'table_name': full_table_name,
|
|
172
|
+
'dimension': dimension,
|
|
173
|
+
'columns': {
|
|
174
|
+
'id': id_column,
|
|
175
|
+
'embedding': embedding_column,
|
|
176
|
+
'document': document_column,
|
|
177
|
+
'metadata': metadata_column,
|
|
178
|
+
'text': text_column
|
|
179
|
+
}
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
self.logger.debug(
|
|
183
|
+
f"Defined collection store: {full_table_name}"
|
|
184
|
+
)
|
|
185
|
+
return full_table_name
|
|
186
|
+
|
|
187
|
+
async def dataset_exists(self, dataset: str = None) -> bool:
|
|
188
|
+
"""Check if a dataset exists in BigQuery."""
|
|
189
|
+
if not self._connected:
|
|
190
|
+
await self.connection()
|
|
191
|
+
|
|
192
|
+
dataset = dataset or self.dataset
|
|
193
|
+
dataset_id = f"{self._project_id}.{dataset}"
|
|
194
|
+
|
|
195
|
+
try:
|
|
196
|
+
self._connection.get_dataset(dataset_id)
|
|
197
|
+
return True
|
|
198
|
+
except NotFound:
|
|
199
|
+
return False
|
|
200
|
+
|
|
201
|
+
async def create_dataset(self, dataset: str = None, location: str = "US") -> Any:
|
|
202
|
+
"""Create a new dataset in BigQuery."""
|
|
203
|
+
if not self._connected:
|
|
204
|
+
await self.connection()
|
|
205
|
+
|
|
206
|
+
dataset = dataset or self.dataset
|
|
207
|
+
|
|
208
|
+
try:
|
|
209
|
+
dataset_ref = bq.DatasetReference(self._project_id, dataset)
|
|
210
|
+
dataset_obj = bq.Dataset(dataset_ref)
|
|
211
|
+
dataset_obj.location = location
|
|
212
|
+
dataset_obj = self._connection.create_dataset(dataset_obj)
|
|
213
|
+
self.logger.debug(f"Created dataset {self._project_id}.{dataset}")
|
|
214
|
+
return dataset_obj
|
|
215
|
+
except Conflict:
|
|
216
|
+
self.logger.warning(f"Dataset {self._project_id}.{dataset} already exists")
|
|
217
|
+
# Get the existing dataset to return it
|
|
218
|
+
return self._connection.get_dataset(f"{self._project_id}.{dataset}")
|
|
219
|
+
except Exception as exc:
|
|
220
|
+
self.logger.error(f"Error creating Dataset: {exc}")
|
|
221
|
+
raise DriverError(
|
|
222
|
+
f"Error creating Dataset: {exc}"
|
|
223
|
+
) from exc
|
|
224
|
+
|
|
225
|
+
async def collection_exists(self, table: str, dataset: str = None) -> bool:
|
|
226
|
+
"""Check if a collection (table) exists in BigQuery."""
|
|
227
|
+
if not self._connected:
|
|
228
|
+
await self.connection()
|
|
229
|
+
|
|
230
|
+
dataset = dataset or self.dataset
|
|
231
|
+
table_id = f"{self._project_id}.{dataset}.{table}"
|
|
232
|
+
|
|
233
|
+
try:
|
|
234
|
+
await self._thread_func(self._connection.get_table, table_id)
|
|
235
|
+
return True
|
|
236
|
+
except NotFound:
|
|
237
|
+
return False
|
|
238
|
+
|
|
239
|
+
async def create_collection(
|
|
240
|
+
self,
|
|
241
|
+
table: str,
|
|
242
|
+
dataset: str = None,
|
|
243
|
+
dimension: int = 768,
|
|
244
|
+
id_column: str = None,
|
|
245
|
+
embedding_column: str = None,
|
|
246
|
+
document_column: str = None,
|
|
247
|
+
metadata_column: str = None,
|
|
248
|
+
**kwargs
|
|
249
|
+
) -> None:
|
|
250
|
+
"""Create a new collection (table) in BigQuery."""
|
|
251
|
+
if not self._connected:
|
|
252
|
+
await self.connection()
|
|
253
|
+
|
|
254
|
+
dataset = dataset or self.dataset
|
|
255
|
+
id_column = id_column or self._id_column
|
|
256
|
+
embedding_column = embedding_column or self._embedding_column
|
|
257
|
+
document_column = document_column or self._document_column
|
|
258
|
+
metadata_column = metadata_column or self._metadata_column
|
|
259
|
+
|
|
260
|
+
table_id = f"{self._project_id}.{dataset}.{table}"
|
|
261
|
+
|
|
262
|
+
try:
|
|
263
|
+
# Check if table already exists
|
|
264
|
+
if await self.collection_exists(table, dataset):
|
|
265
|
+
self.logger.info(f"Collection {table_id} already exists")
|
|
266
|
+
return
|
|
267
|
+
|
|
268
|
+
# Define table schema - use FLOAT64 REPEATED instead of STRUCT
|
|
269
|
+
schema = [
|
|
270
|
+
bq.SchemaField(id_column, "STRING", mode="REQUIRED"),
|
|
271
|
+
bq.SchemaField(embedding_column, "FLOAT64", mode="REPEATED"),
|
|
272
|
+
bq.SchemaField(document_column, "STRING", mode="NULLABLE"),
|
|
273
|
+
bq.SchemaField(metadata_column, "JSON", mode="NULLABLE"),
|
|
274
|
+
bq.SchemaField(self._text_column, "STRING", mode="NULLABLE"),
|
|
275
|
+
bq.SchemaField("collection_id", "STRING", mode="NULLABLE"),
|
|
276
|
+
]
|
|
277
|
+
|
|
278
|
+
table_ref = bq.Table(table_id, schema=schema)
|
|
279
|
+
table_ref = await self._thread_func(self._connection.create_table, table_ref)
|
|
280
|
+
|
|
281
|
+
self.logger.debug(f"Created collection {table_id}")
|
|
282
|
+
|
|
283
|
+
# Cache the collection store
|
|
284
|
+
self._define_collection_store(
|
|
285
|
+
table=table,
|
|
286
|
+
dataset=dataset,
|
|
287
|
+
dimension=dimension,
|
|
288
|
+
id_column=id_column,
|
|
289
|
+
embedding_column=embedding_column,
|
|
290
|
+
document_column=document_column,
|
|
291
|
+
metadata_column=metadata_column
|
|
292
|
+
)
|
|
293
|
+
|
|
294
|
+
except Exception as e:
|
|
295
|
+
self.logger.error(f"Error creating collection: {e}")
|
|
296
|
+
raise RuntimeError(
|
|
297
|
+
f"Failed to create collection: {e}"
|
|
298
|
+
) from e
|
|
299
|
+
|
|
300
|
+
async def drop_collection(self, table: str, dataset: str = None) -> None:
|
|
301
|
+
"""
|
|
302
|
+
Drops the specified table in the given dataset.
|
|
303
|
+
|
|
304
|
+
Args:
|
|
305
|
+
table: The name of the table to drop.
|
|
306
|
+
dataset: The dataset where the table resides (optional, uses default if not provided).
|
|
307
|
+
"""
|
|
308
|
+
if not self._connected:
|
|
309
|
+
await self.connection()
|
|
310
|
+
|
|
311
|
+
dataset = dataset or self.dataset
|
|
312
|
+
table_id = f"{self._project_id}.{dataset}.{table}"
|
|
313
|
+
|
|
314
|
+
try:
|
|
315
|
+
# Check if table exists first
|
|
316
|
+
if not await self.collection_exists(table, dataset):
|
|
317
|
+
self.logger.warning(f"Table '{table_id}' does not exist, nothing to drop")
|
|
318
|
+
return
|
|
319
|
+
|
|
320
|
+
# Drop the table
|
|
321
|
+
self._connection.delete_table(table_id, not_found_ok=True)
|
|
322
|
+
|
|
323
|
+
# Remove from cache if it exists
|
|
324
|
+
if table_id in self._collection_store_cache:
|
|
325
|
+
del self._collection_store_cache[table_id]
|
|
326
|
+
|
|
327
|
+
self.logger.debug(f"Table '{table_id}' dropped successfully")
|
|
328
|
+
|
|
329
|
+
except Exception as e:
|
|
330
|
+
self.logger.error(
|
|
331
|
+
f"Error dropping table '{table_id}': {e}"
|
|
332
|
+
)
|
|
333
|
+
raise RuntimeError(
|
|
334
|
+
f"Failed to drop table '{table_id}': {e}"
|
|
335
|
+
) from e
|
|
336
|
+
|
|
337
|
+
async def prepare_embedding_table(
|
|
338
|
+
self,
|
|
339
|
+
table: str,
|
|
340
|
+
dataset: str = None,
|
|
341
|
+
dimension: int = 768,
|
|
342
|
+
id_column: str = 'id',
|
|
343
|
+
embedding_column: str = 'embedding',
|
|
344
|
+
document_column: str = 'document',
|
|
345
|
+
metadata_column: str = 'metadata',
|
|
346
|
+
**kwargs
|
|
347
|
+
) -> bool:
|
|
348
|
+
"""Prepare an existing BigQuery table for embedding storage."""
|
|
349
|
+
if not self._connected:
|
|
350
|
+
await self.connection()
|
|
351
|
+
|
|
352
|
+
dataset = dataset or self.dataset
|
|
353
|
+
table_id = f"{self._project_id}.{dataset}.{table}"
|
|
354
|
+
|
|
355
|
+
try:
|
|
356
|
+
# Get existing table
|
|
357
|
+
table_ref = self._connection.get_table(table_id)
|
|
358
|
+
current_schema = table_ref.schema
|
|
359
|
+
|
|
360
|
+
# Check if embedding columns already exist
|
|
361
|
+
existing_fields = {field.name for field in current_schema}
|
|
362
|
+
new_fields = []
|
|
363
|
+
|
|
364
|
+
if embedding_column not in existing_fields:
|
|
365
|
+
new_fields.append(
|
|
366
|
+
bq.SchemaField(
|
|
367
|
+
embedding_column,
|
|
368
|
+
"REPEATED",
|
|
369
|
+
mode="NULLABLE",
|
|
370
|
+
fields=[bq.SchemaField("value", "FLOAT64")]
|
|
371
|
+
)
|
|
372
|
+
)
|
|
373
|
+
|
|
374
|
+
if metadata_column not in existing_fields:
|
|
375
|
+
new_fields.append(
|
|
376
|
+
bq.SchemaField(metadata_column, "JSON", mode="NULLABLE")
|
|
377
|
+
)
|
|
378
|
+
|
|
379
|
+
if "collection_id" not in existing_fields:
|
|
380
|
+
new_fields.append(
|
|
381
|
+
bq.SchemaField("collection_id", "STRING", mode="NULLABLE")
|
|
382
|
+
)
|
|
383
|
+
|
|
384
|
+
# Add new fields if any
|
|
385
|
+
if new_fields:
|
|
386
|
+
new_schema = list(current_schema) + new_fields
|
|
387
|
+
table_ref.schema = new_schema
|
|
388
|
+
table_ref = self._connection.update_table(table_ref, ["schema"])
|
|
389
|
+
self.logger.info(f"Updated table {table_id} schema with embedding columns")
|
|
390
|
+
|
|
391
|
+
# Cache the collection store
|
|
392
|
+
self._define_collection_store(
|
|
393
|
+
table=table,
|
|
394
|
+
dataset=dataset,
|
|
395
|
+
dimension=dimension,
|
|
396
|
+
id_column=id_column,
|
|
397
|
+
embedding_column=embedding_column,
|
|
398
|
+
document_column=document_column,
|
|
399
|
+
metadata_column=metadata_column
|
|
400
|
+
)
|
|
401
|
+
|
|
402
|
+
return True
|
|
403
|
+
|
|
404
|
+
except Exception as e:
|
|
405
|
+
self.logger.error(f"Error preparing embedding table: {e}")
|
|
406
|
+
raise RuntimeError(f"Failed to prepare embedding table: {e}") from e
|
|
407
|
+
|
|
408
|
+
async def _wait_for_table_insert_ready(self, table_id: str, max_wait_seconds: int = 30, poll_interval: float = 0.5) -> bool:
|
|
409
|
+
"""
|
|
410
|
+
Wait for a table to be ready for insert operations specifically.
|
|
411
|
+
|
|
412
|
+
Args:
|
|
413
|
+
table_id: Fully qualified table ID (project.dataset.table)
|
|
414
|
+
max_wait_seconds: Maximum time to wait in seconds
|
|
415
|
+
poll_interval: Time between polling attempts in seconds
|
|
416
|
+
|
|
417
|
+
Returns:
|
|
418
|
+
bool: True if table is ready for inserts, False if timeout
|
|
419
|
+
"""
|
|
420
|
+
start_time = time.time()
|
|
421
|
+
|
|
422
|
+
while time.time() - start_time < max_wait_seconds:
|
|
423
|
+
try:
|
|
424
|
+
# First check if table exists via get_table
|
|
425
|
+
table_ref = await self._thread_func(self._connection.get_table, table_id)
|
|
426
|
+
|
|
427
|
+
# Then test with a simple query to see if the table is accessible for operations
|
|
428
|
+
# This is a lightweight way to test table readiness
|
|
429
|
+
test_query = f"SELECT COUNT(*) as row_count FROM `{table_id}` LIMIT 1"
|
|
430
|
+
query_job = await self._thread_func(self._connection.query, test_query)
|
|
431
|
+
await self._thread_func(query_job.result)
|
|
432
|
+
|
|
433
|
+
# If we get here without exception, table is ready for operations
|
|
434
|
+
self.logger.debug(f"Table {table_id} is ready for insert operations")
|
|
435
|
+
return True
|
|
436
|
+
|
|
437
|
+
except NotFound:
|
|
438
|
+
# Table not ready yet, wait and retry
|
|
439
|
+
self.logger.debug(f"Table {table_id} not ready yet, waiting {poll_interval}s...")
|
|
440
|
+
await asyncio.sleep(poll_interval)
|
|
441
|
+
except Exception as e:
|
|
442
|
+
# For other errors, we might want to continue trying for a bit
|
|
443
|
+
# as BigQuery can have temporary inconsistencies
|
|
444
|
+
self.logger.debug(f"Checking table readiness, got error (will retry): {e}")
|
|
445
|
+
await asyncio.sleep(poll_interval)
|
|
446
|
+
|
|
447
|
+
self.logger.warning(f"Table {table_id} not ready after {max_wait_seconds} seconds")
|
|
448
|
+
return False
|
|
449
|
+
|
|
450
|
+
async def add_documents(
|
|
451
|
+
self,
|
|
452
|
+
documents: List[Document],
|
|
453
|
+
table: str = None,
|
|
454
|
+
dataset: str = None,
|
|
455
|
+
embedding_column: str = 'embedding',
|
|
456
|
+
content_column: str = 'document',
|
|
457
|
+
metadata_column: str = 'metadata',
|
|
458
|
+
**kwargs
|
|
459
|
+
) -> None:
|
|
460
|
+
"""Add documents to BigQuery table with embeddings."""
|
|
461
|
+
if not self._connected:
|
|
462
|
+
await self.connection()
|
|
463
|
+
|
|
464
|
+
table = table or self.table_name
|
|
465
|
+
dataset = dataset or self.dataset
|
|
466
|
+
|
|
467
|
+
if not table:
|
|
468
|
+
raise ValueError("Table name must be provided")
|
|
469
|
+
|
|
470
|
+
table_id = f"{self._project_id}.{dataset}.{table}"
|
|
471
|
+
|
|
472
|
+
# Ensure collection exists
|
|
473
|
+
if not await self.collection_exists(table, dataset):
|
|
474
|
+
await self.create_collection(
|
|
475
|
+
table=table,
|
|
476
|
+
dataset=dataset,
|
|
477
|
+
dimension=self.dimension
|
|
478
|
+
)
|
|
479
|
+
|
|
480
|
+
# If we just created the table, wait for it to be ready for inserts
|
|
481
|
+
is_ready = await self._wait_for_table_insert_ready(table_id, max_wait_seconds=60)
|
|
482
|
+
if not is_ready:
|
|
483
|
+
raise RuntimeError(
|
|
484
|
+
f"Table {table_id} was created but not ready for insert operations within timeout"
|
|
485
|
+
)
|
|
486
|
+
|
|
487
|
+
# Process documents
|
|
488
|
+
texts = [doc.page_content for doc in documents]
|
|
489
|
+
# Thread the embedding generation as it can be slow
|
|
490
|
+
embeddings = await self._thread_func(self._embed_.embed_documents, texts)
|
|
491
|
+
metadatas = [doc.metadata for doc in documents]
|
|
492
|
+
|
|
493
|
+
# Prepare data for BigQuery (this is fast, no threading needed)
|
|
494
|
+
rows_to_insert = []
|
|
495
|
+
for i, doc in enumerate(documents):
|
|
496
|
+
embedding_vector = embeddings[i]
|
|
497
|
+
if isinstance(embedding_vector, np.ndarray):
|
|
498
|
+
embedding_vector = embedding_vector.tolist()
|
|
499
|
+
|
|
500
|
+
embedding_array = [float(val) for val in embedding_vector]
|
|
501
|
+
|
|
502
|
+
metadata_value = metadatas[i] or {}
|
|
503
|
+
metadata_json = self._json.dumps(metadata_value) if metadata_value else self._json.dumps({})
|
|
504
|
+
|
|
505
|
+
row = {
|
|
506
|
+
self._id_column: str(uuid.uuid4()),
|
|
507
|
+
embedding_column: embedding_array,
|
|
508
|
+
content_column: texts[i],
|
|
509
|
+
metadata_column: metadata_json,
|
|
510
|
+
"collection_id": str(uuid.uuid4())
|
|
511
|
+
}
|
|
512
|
+
rows_to_insert.append(row)
|
|
513
|
+
|
|
514
|
+
# Insert data with retry logic and longer delays
|
|
515
|
+
max_retries = 5
|
|
516
|
+
retry_delay = 2.0
|
|
517
|
+
|
|
518
|
+
for attempt in range(max_retries):
|
|
519
|
+
try:
|
|
520
|
+
# Always get a fresh table reference
|
|
521
|
+
table_ref = await self._thread_func(self._connection.get_table, table_id)
|
|
522
|
+
|
|
523
|
+
# Add a small delay even on first attempt if table was just created
|
|
524
|
+
if attempt == 0:
|
|
525
|
+
await asyncio.sleep(1.0)
|
|
526
|
+
|
|
527
|
+
errors = await self._thread_func(
|
|
528
|
+
self._connection.insert_rows_json, table_ref, rows_to_insert
|
|
529
|
+
)
|
|
530
|
+
|
|
531
|
+
if errors:
|
|
532
|
+
self.logger.error(f"Errors inserting rows: {errors}")
|
|
533
|
+
raise RuntimeError(f"Failed to insert documents: {errors}")
|
|
534
|
+
|
|
535
|
+
self.logger.info(f"Successfully added {len(documents)} documents to {table_id}")
|
|
536
|
+
return # Success, exit the retry loop
|
|
537
|
+
|
|
538
|
+
except NotFound as e:
|
|
539
|
+
if attempt < max_retries - 1:
|
|
540
|
+
self.logger.warning(
|
|
541
|
+
f"Table {table_id} not found for insert (attempt {attempt + 1}/{max_retries}), retrying in {retry_delay}s..."
|
|
542
|
+
)
|
|
543
|
+
await asyncio.sleep(retry_delay)
|
|
544
|
+
retry_delay = min(retry_delay * 1.5, 10.0) # Cap at 10 seconds
|
|
545
|
+
else:
|
|
546
|
+
self.logger.error(
|
|
547
|
+
f"Table {table_id} still not found for insert after {max_retries} attempts"
|
|
548
|
+
)
|
|
549
|
+
raise
|
|
550
|
+
except Exception as e:
|
|
551
|
+
if attempt < max_retries - 1 and "not found" in str(e).lower():
|
|
552
|
+
# Treat any "not found" error as retryable
|
|
553
|
+
self.logger.warning(
|
|
554
|
+
f"Insert failed with 'not found' error (attempt {attempt + 1}/{max_retries}), retrying in {retry_delay}s..."
|
|
555
|
+
)
|
|
556
|
+
await asyncio.sleep(retry_delay)
|
|
557
|
+
retry_delay = min(retry_delay * 1.5, 10.0)
|
|
558
|
+
else:
|
|
559
|
+
self.logger.error(f"Error adding documents: {e}")
|
|
560
|
+
raise
|
|
561
|
+
|
|
562
|
+
def _get_distance_function(self, metric: str = None) -> str:
|
|
563
|
+
"""Get BigQuery ML distance function based on strategy."""
|
|
564
|
+
strategy = metric or self.distance_strategy
|
|
565
|
+
|
|
566
|
+
if isinstance(strategy, str):
|
|
567
|
+
metric_mapping = {
|
|
568
|
+
'COSINE': DistanceStrategy.COSINE,
|
|
569
|
+
'L2': DistanceStrategy.EUCLIDEAN_DISTANCE,
|
|
570
|
+
'EUCLIDEAN': DistanceStrategy.EUCLIDEAN_DISTANCE,
|
|
571
|
+
'DOT': DistanceStrategy.DOT_PRODUCT,
|
|
572
|
+
}
|
|
573
|
+
strategy = metric_mapping.get(strategy.upper(), DistanceStrategy.COSINE)
|
|
574
|
+
|
|
575
|
+
if strategy == DistanceStrategy.COSINE:
|
|
576
|
+
return "ML.DISTANCE" # Cosine distance in BigQuery ML
|
|
577
|
+
elif strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
|
|
578
|
+
return "ML.EUCLIDEAN_DISTANCE"
|
|
579
|
+
elif strategy == DistanceStrategy.DOT_PRODUCT:
|
|
580
|
+
return "ML.DOT_PRODUCT"
|
|
581
|
+
else:
|
|
582
|
+
return "ML.DISTANCE" # Default to cosine
|
|
583
|
+
|
|
584
|
+
async def similarity_search(
|
|
585
|
+
self,
|
|
586
|
+
query: str,
|
|
587
|
+
table: str = None,
|
|
588
|
+
dataset: str = None,
|
|
589
|
+
k: Optional[int] = None,
|
|
590
|
+
limit: int = None,
|
|
591
|
+
metadata_filters: Optional[Dict[str, Any]] = None,
|
|
592
|
+
score_threshold: Optional[float] = None,
|
|
593
|
+
metric: str = None,
|
|
594
|
+
embedding_column: str = 'embedding',
|
|
595
|
+
content_column: str = 'document',
|
|
596
|
+
metadata_column: str = 'metadata',
|
|
597
|
+
id_column: str = 'id',
|
|
598
|
+
**kwargs
|
|
599
|
+
) -> List[SearchResult]:
|
|
600
|
+
"""Perform similarity search using BigQuery ML functions."""
|
|
601
|
+
if not self._connected:
|
|
602
|
+
await self.connection()
|
|
603
|
+
|
|
604
|
+
table = table or self.table_name
|
|
605
|
+
dataset = dataset or self.dataset
|
|
606
|
+
|
|
607
|
+
if k and not limit:
|
|
608
|
+
limit = k
|
|
609
|
+
if not limit:
|
|
610
|
+
limit = 10
|
|
611
|
+
|
|
612
|
+
table_id = f"{self._project_id}.{dataset}.{table}"
|
|
613
|
+
|
|
614
|
+
# Get query embedding
|
|
615
|
+
# query_embedding = self._embed_.embed_query(query)
|
|
616
|
+
query_embedding = await self._thread_func(self._embed_.embed_query, query)
|
|
617
|
+
if isinstance(query_embedding, np.ndarray):
|
|
618
|
+
query_embedding = query_embedding.tolist()
|
|
619
|
+
|
|
620
|
+
# Convert embedding to BigQuery array literal format
|
|
621
|
+
embedding_literal = "[" + ",".join([str(float(val)) for val in query_embedding]) + "]"
|
|
622
|
+
|
|
623
|
+
# Build the SQL query
|
|
624
|
+
distance_func = self._get_distance_function(metric)
|
|
625
|
+
|
|
626
|
+
# Create the SQL query with embedded array literal
|
|
627
|
+
sql_query = f"""
|
|
628
|
+
SELECT
|
|
629
|
+
{id_column},
|
|
630
|
+
{content_column},
|
|
631
|
+
{metadata_column},
|
|
632
|
+
{distance_func}({embedding_column}, {embedding_literal}) as distance
|
|
633
|
+
FROM `{table_id}`
|
|
634
|
+
WHERE {embedding_column} IS NOT NULL
|
|
635
|
+
"""
|
|
636
|
+
|
|
637
|
+
# Add metadata filters
|
|
638
|
+
filter_params = []
|
|
639
|
+
if metadata_filters:
|
|
640
|
+
filter_conditions = []
|
|
641
|
+
for key, value in metadata_filters.items():
|
|
642
|
+
if isinstance(value, str):
|
|
643
|
+
filter_conditions.append(f"JSON_EXTRACT_SCALAR({metadata_column}, '$.{key}') = @filter_{key}")
|
|
644
|
+
filter_params.append(bq.ScalarQueryParameter(f"filter_{key}", "STRING", value))
|
|
645
|
+
else:
|
|
646
|
+
filter_conditions.append(f"JSON_EXTRACT_SCALAR({metadata_column}, '$.{key}') = @filter_{key}")
|
|
647
|
+
filter_params.append(bq.ScalarQueryParameter(f"filter_{key}", "STRING", str(value)))
|
|
648
|
+
|
|
649
|
+
if filter_conditions:
|
|
650
|
+
sql_query += " AND " + " AND ".join(filter_conditions)
|
|
651
|
+
|
|
652
|
+
# Add score threshold
|
|
653
|
+
if score_threshold is not None:
|
|
654
|
+
sql_query += f" AND {distance_func}({embedding_column}, {embedding_literal}) <= {score_threshold}"
|
|
655
|
+
|
|
656
|
+
# Order and limit
|
|
657
|
+
sql_query += f" ORDER BY distance ASC"
|
|
658
|
+
if limit:
|
|
659
|
+
sql_query += f" LIMIT {limit}"
|
|
660
|
+
|
|
661
|
+
# Configure query parameters
|
|
662
|
+
job_config = None
|
|
663
|
+
if filter_params:
|
|
664
|
+
job_config = bq.QueryJobConfig(
|
|
665
|
+
query_parameters=filter_params
|
|
666
|
+
)
|
|
667
|
+
|
|
668
|
+
try:
|
|
669
|
+
# Execute query
|
|
670
|
+
query_job = await self._thread_func(
|
|
671
|
+
self._connection.query, sql_query, job_config=job_config
|
|
672
|
+
)
|
|
673
|
+
results = await self._thread_func(query_job.result)
|
|
674
|
+
|
|
675
|
+
# Process results
|
|
676
|
+
search_results = []
|
|
677
|
+
for row in results:
|
|
678
|
+
metadata_str = row[metadata_column]
|
|
679
|
+
if isinstance(metadata_str, str):
|
|
680
|
+
# Ensure metadata is a JSON string
|
|
681
|
+
metadata_str = metadata_str.strip()
|
|
682
|
+
metadata = self._json.loads(metadata_str)
|
|
683
|
+
else:
|
|
684
|
+
metadata = dict(metadata_str) if metadata_str else {}
|
|
685
|
+
|
|
686
|
+
search_result = SearchResult(
|
|
687
|
+
id=row[id_column],
|
|
688
|
+
content=row[content_column],
|
|
689
|
+
metadata=metadata,
|
|
690
|
+
score=float(row.distance)
|
|
691
|
+
)
|
|
692
|
+
search_results.append(search_result)
|
|
693
|
+
|
|
694
|
+
self.logger.debug(
|
|
695
|
+
f"Similarity search returned {len(search_results)} results"
|
|
696
|
+
)
|
|
697
|
+
return search_results
|
|
698
|
+
|
|
699
|
+
except Exception as e:
|
|
700
|
+
self.logger.error(f"Error during similarity search: {e}")
|
|
701
|
+
raise
|
|
702
|
+
|
|
703
|
+
async def mmr_search(
|
|
704
|
+
self,
|
|
705
|
+
query: str,
|
|
706
|
+
table: str = None,
|
|
707
|
+
dataset: str = None,
|
|
708
|
+
k: int = 10,
|
|
709
|
+
fetch_k: int = None,
|
|
710
|
+
lambda_mult: float = 0.5,
|
|
711
|
+
metadata_filters: Optional[Dict[str, Any]] = None,
|
|
712
|
+
score_threshold: Optional[float] = None,
|
|
713
|
+
metric: str = None,
|
|
714
|
+
embedding_column: str = 'embedding',
|
|
715
|
+
content_column: str = 'document',
|
|
716
|
+
metadata_column: str = 'metadata',
|
|
717
|
+
id_column: str = 'id',
|
|
718
|
+
**kwargs
|
|
719
|
+
) -> List[SearchResult]:
|
|
720
|
+
"""
|
|
721
|
+
Perform Maximal Marginal Relevance (MMR) search.
|
|
722
|
+
|
|
723
|
+
Since BigQuery doesn't have native MMR support, we fetch more candidates
|
|
724
|
+
and perform MMR selection in Python.
|
|
725
|
+
"""
|
|
726
|
+
if not self._connected:
|
|
727
|
+
await self.connection()
|
|
728
|
+
|
|
729
|
+
# Default to fetching 3x more candidates than final results
|
|
730
|
+
if fetch_k is None:
|
|
731
|
+
fetch_k = max(k * 3, 20)
|
|
732
|
+
|
|
733
|
+
# Step 1: Get initial candidates using similarity search
|
|
734
|
+
candidates = await self.similarity_search(
|
|
735
|
+
query=query,
|
|
736
|
+
table=table,
|
|
737
|
+
dataset=dataset,
|
|
738
|
+
limit=fetch_k,
|
|
739
|
+
metadata_filters=metadata_filters,
|
|
740
|
+
score_threshold=score_threshold,
|
|
741
|
+
metric=metric,
|
|
742
|
+
embedding_column=embedding_column,
|
|
743
|
+
content_column=content_column,
|
|
744
|
+
metadata_column=metadata_column,
|
|
745
|
+
id_column=id_column,
|
|
746
|
+
**kwargs
|
|
747
|
+
)
|
|
748
|
+
|
|
749
|
+
if len(candidates) <= k:
|
|
750
|
+
return candidates
|
|
751
|
+
|
|
752
|
+
# Step 2: Fetch embeddings for MMR computation
|
|
753
|
+
candidate_embeddings = await self._fetch_embeddings_for_mmr(
|
|
754
|
+
candidate_ids=[result.id for result in candidates],
|
|
755
|
+
table=table,
|
|
756
|
+
dataset=dataset,
|
|
757
|
+
embedding_column=embedding_column,
|
|
758
|
+
id_column=id_column
|
|
759
|
+
)
|
|
760
|
+
|
|
761
|
+
# Step 3: Get query embedding
|
|
762
|
+
query_embedding = self._embed_.embed_query(query)
|
|
763
|
+
|
|
764
|
+
# Step 4: Run MMR algorithm
|
|
765
|
+
selected_results = self._mmr_algorithm(
|
|
766
|
+
query_embedding=query_embedding,
|
|
767
|
+
candidates=candidates,
|
|
768
|
+
candidate_embeddings=candidate_embeddings,
|
|
769
|
+
k=k,
|
|
770
|
+
lambda_mult=lambda_mult,
|
|
771
|
+
metric=metric or self.distance_strategy
|
|
772
|
+
)
|
|
773
|
+
|
|
774
|
+
self.logger.info(
|
|
775
|
+
f"MMR search selected {len(selected_results)} results from {len(candidates)} candidates"
|
|
776
|
+
)
|
|
777
|
+
|
|
778
|
+
return selected_results
|
|
779
|
+
|
|
780
|
+
async def _fetch_embeddings_for_mmr(
|
|
781
|
+
self,
|
|
782
|
+
candidate_ids: List[str],
|
|
783
|
+
table: str,
|
|
784
|
+
dataset: str,
|
|
785
|
+
embedding_column: str,
|
|
786
|
+
id_column: str
|
|
787
|
+
) -> Dict[str, np.ndarray]:
|
|
788
|
+
"""Fetch embedding vectors for candidate documents from BigQuery."""
|
|
789
|
+
table_id = f"{self._project_id}.{dataset}.{table}"
|
|
790
|
+
|
|
791
|
+
# Create placeholders for the IDs
|
|
792
|
+
id_placeholders = ', '.join([f"'{id_}'" for id_ in candidate_ids])
|
|
793
|
+
|
|
794
|
+
sql_query = f"""
|
|
795
|
+
SELECT {id_column}, {embedding_column}
|
|
796
|
+
FROM `{table_id}`
|
|
797
|
+
WHERE {id_column} IN ({id_placeholders})
|
|
798
|
+
"""
|
|
799
|
+
|
|
800
|
+
try:
|
|
801
|
+
query_job = await self._thread_func(self._connection.query, sql_query)
|
|
802
|
+
results = await self._thread_func(query_job.result)
|
|
803
|
+
|
|
804
|
+
embeddings_dict = {}
|
|
805
|
+
for row in results:
|
|
806
|
+
doc_id = row[id_column]
|
|
807
|
+
embedding_data = row[embedding_column]
|
|
808
|
+
|
|
809
|
+
# Convert BigQuery array format back to numpy array
|
|
810
|
+
if embedding_data:
|
|
811
|
+
embedding_values = [item['value'] for item in embedding_data]
|
|
812
|
+
embedding_values = embedding_data if isinstance(embedding_data, list) else embedding_data.tolist()
|
|
813
|
+
embeddings_dict[doc_id] = np.array(embedding_values, dtype=np.float32)
|
|
814
|
+
|
|
815
|
+
return embeddings_dict
|
|
816
|
+
|
|
817
|
+
except Exception as e:
|
|
818
|
+
self.logger.error(f"Error fetching embeddings for MMR: {e}")
|
|
819
|
+
raise
|
|
820
|
+
|
|
821
|
+
    def _mmr_algorithm(
        self,
        query_embedding: np.ndarray,
        candidates: List[SearchResult],
        candidate_embeddings: Dict[str, np.ndarray],
        k: int,
        lambda_mult: float,
        metric: str
    ) -> List[SearchResult]:
        """Core MMR algorithm implementation (same as PgVectorStore)."""
        if len(candidates) <= k:
            return candidates

        # Convert query embedding to numpy array
        if not isinstance(query_embedding, np.ndarray):
            query_embedding = np.array(query_embedding, dtype=np.float32)

        # Prepare data structures
        selected_indices = []
        remaining_indices = list(range(len(candidates)))

        # Step 1: Select the most relevant document first
        query_similarities = []
        for candidate in candidates:
            doc_embedding = candidate_embeddings.get(candidate.id)
            if doc_embedding is not None:
                similarity = self._compute_similarity(query_embedding, doc_embedding, metric)
                query_similarities.append(similarity)
            else:
                # Fallback to distance score if embedding not available
                query_similarities.append(1.0 / (1.0 + candidate.score))

        # Select the most similar document first
        best_idx = np.argmax(query_similarities)
        selected_indices.append(best_idx)
        remaining_indices.remove(best_idx)

        # Step 2: Iteratively select remaining documents using MMR
        for _ in range(min(k - 1, len(remaining_indices))):
            mmr_scores = []

            for idx in remaining_indices:
                candidate = candidates[idx]
                doc_embedding = candidate_embeddings.get(candidate.id)

                if doc_embedding is None:
                    # Fallback scoring if embedding not available
                    mmr_score = lambda_mult * query_similarities[idx]
                    mmr_scores.append(mmr_score)
                    continue

                # Relevance: similarity to query
                relevance = query_similarities[idx]

                # Diversity: maximum similarity to already selected documents
                max_similarity_to_selected = 0.0
                for selected_idx in selected_indices:
                    selected_candidate = candidates[selected_idx]
                    selected_embedding = candidate_embeddings.get(selected_candidate.id)

                    if selected_embedding is not None:
                        similarity = self._compute_similarity(doc_embedding, selected_embedding, metric)
                        max_similarity_to_selected = max(max_similarity_to_selected, similarity)

                # MMR formula: λ * relevance - (1-λ) * max_similarity_to_selected
                mmr_score = (
                    lambda_mult * relevance -
                    (1.0 - lambda_mult) * max_similarity_to_selected
                )
                mmr_scores.append(mmr_score)

            # Select document with highest MMR score
            if mmr_scores:
                best_remaining_idx = np.argmax(mmr_scores)
                best_idx = remaining_indices[best_remaining_idx]
                selected_indices.append(best_idx)
                remaining_indices.remove(best_idx)

        # Step 3: Return selected results with MMR scores in metadata
        selected_results = []
        for i, idx in enumerate(selected_indices):
            result = candidates[idx]
            # Add MMR ranking to metadata
            enhanced_metadata = dict(result.metadata)
            enhanced_metadata['mmr_rank'] = i + 1
            enhanced_metadata['mmr_lambda'] = lambda_mult
            enhanced_metadata['original_rank'] = idx + 1

            enhanced_result = SearchResult(
                id=result.id,
                content=result.content,
                metadata=enhanced_metadata,
                score=result.score
            )
            selected_results.append(enhanced_result)

        return selected_results

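    # Worked example of the MMR trade-off above (illustrative numbers only):
    # with lambda_mult = 0.7, a candidate with relevance 0.9 that is highly
    # redundant (max similarity 0.8 to the already-selected set) scores
    #     0.7 * 0.9 - 0.3 * 0.8 = 0.39,
    # while a slightly less relevant but more diverse candidate (relevance 0.8,
    # max similarity 0.2) scores
    #     0.7 * 0.8 - 0.3 * 0.2 = 0.50
    # and is picked first. lambda_mult = 1.0 reduces to pure relevance ranking;
    # lower values trade relevance for diversity.
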
    def _compute_similarity(
        self,
        embedding1: np.ndarray,
        embedding2: np.ndarray,
        metric: Union[str, Any]
    ) -> float:
        """Compute similarity between two embeddings (same as PgVectorStore)."""
        # Normalize inputs (lists or other array-likes) to numpy arrays
        if not isinstance(embedding1, np.ndarray):
            embedding1 = np.array(embedding1, dtype=np.float32)
        if not isinstance(embedding2, np.ndarray):
            embedding2 = np.array(embedding2, dtype=np.float32)

        # Ensure embeddings are 2D arrays for sklearn
        emb1 = embedding1.reshape(1, -1)
        emb2 = embedding2.reshape(1, -1)

        # Convert string metrics to DistanceStrategy enum if needed
        if isinstance(metric, str):
            metric_mapping = {
                'COSINE': DistanceStrategy.COSINE,
                'L2': DistanceStrategy.EUCLIDEAN_DISTANCE,
                'EUCLIDEAN': DistanceStrategy.EUCLIDEAN_DISTANCE,
                'DOT': DistanceStrategy.DOT_PRODUCT,
            }
            strategy = metric_mapping.get(metric.upper(), DistanceStrategy.COSINE)
        else:
            strategy = metric

        if strategy == DistanceStrategy.COSINE:
            # Cosine similarity
            similarity = cosine_similarity(emb1, emb2)[0, 0]
            return float(similarity)
        elif strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
            # Convert Euclidean distance to similarity
            distance = euclidean_distances(emb1, emb2)[0, 0]
            similarity = 1.0 / (1.0 + distance)
            return float(similarity)
        elif strategy == DistanceStrategy.DOT_PRODUCT:
            # Dot product
            similarity = np.dot(embedding1.flatten(), embedding2.flatten())
            return float(similarity)
        else:
            # Default to cosine similarity
            similarity = cosine_similarity(emb1, emb2)[0, 0]
            return float(similarity)

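    # Illustrative values for the metric handling above: with metric='COSINE',
    # vectors pointing the same way yield 1.0 and orthogonal vectors 0.0; with
    # metric='L2' or 'EUCLIDEAN', a distance of 1.0 maps to 1 / (1 + 1.0) = 0.5
    # and a distance of 3.0 to 0.25, so larger distances give smaller
    # similarities; unrecognized strings fall back to cosine similarity.
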
    async def delete_documents(
        self,
        documents: Optional[List[Document]] = None,
        pk: str = 'source_type',
        values: Optional[Union[str, List[str]]] = None,
        table: Optional[str] = None,
        dataset: Optional[str] = None,
        metadata_column: Optional[str] = None,
        **kwargs
    ) -> int:
        """Delete documents from BigQuery table based on metadata field values."""
        if not self._connected:
            await self.connection()

        table = table or self.table_name
        dataset = dataset or self.dataset
        metadata_column = metadata_column or self._metadata_column

        if not table:
            raise ValueError("Table name must be provided")

        table_id = f"{self._project_id}.{dataset}.{table}"

        # Extract values to delete
        delete_values = []
        if values is not None:
            if isinstance(values, str):
                delete_values = [values]
            else:
                delete_values = list(values)
        elif documents:
            for doc in documents:
                if hasattr(doc, 'metadata') and doc.metadata and pk in doc.metadata:
                    value = doc.metadata[pk]
                    if value and value not in delete_values:
                        delete_values.append(value)
        else:
            raise ValueError("Either 'documents' or 'values' parameter must be provided")

        if not delete_values:
            self.logger.warning(f"No values found for field '{pk}' to delete")
            return 0

        deleted_count = 0

        try:
            for value in delete_values:
                # Create delete query using JSON extraction
                delete_query = f"""
                DELETE FROM `{table_id}`
                WHERE JSON_EXTRACT_SCALAR({metadata_column}, '$.{pk}') = @value
                """

                job_config = bq.QueryJobConfig(
                    query_parameters=[
                        bq.ScalarQueryParameter("value", "STRING", str(value))
                    ]
                )

                query_job = await self._thread_func(
                    self._connection.query,
                    delete_query,
                    job_config=job_config
                )
                await self._thread_func(query_job.result)

                rows_deleted = query_job.num_dml_affected_rows or 0
                deleted_count += rows_deleted

                self.logger.info(
                    f"Deleted {rows_deleted} documents with {pk}='{value}' from {table_id}"
                )

            self.logger.info(f"Total deleted: {deleted_count} documents")
            return deleted_count

        except Exception as e:
            self.logger.error(f"Error deleting documents: {e}")
            raise RuntimeError(f"Failed to delete documents: {e}") from e

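    # Hypothetical usage of delete_documents() above ('store' is an assumed
    # instance name, not defined in this module):
    #
    #     deleted = await store.delete_documents(
    #         pk="source_type", values=["web_page", "pdf"]
    #     )
    #
    # Each value issues one parameterized DELETE comparing
    # JSON_EXTRACT_SCALAR(metadata, '$.source_type') against that value.
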
    async def delete_documents_by_filter(
        self,
        filter_dict: Dict[str, Union[str, List[str]]],
        table: Optional[str] = None,
        dataset: Optional[str] = None,
    ) -> int:
        """Delete documents based on multiple metadata field conditions."""
        if not self._connected:
            await self.connection()
        if not filter_dict:
            raise ValueError("filter_dict cannot be empty")

        table = table or self.table_name
        dataset = dataset or self.dataset
        table_id = f"`{self._project_id}.{dataset}.{table}`"

        where_conditions = []
        query_params = []
        for i, (field, values) in enumerate(filter_dict.items()):
            safe_field = field.replace("'", "\\'")
            param_name = f"val_{i}"
            if isinstance(values, (list, tuple)):
                where_conditions.append(
                    f"JSON_VALUE({self._metadata_column}, '$.{safe_field}') IN UNNEST(@{param_name})"
                )
                query_params.append(
                    bq.ArrayQueryParameter(param_name, "STRING", [str(v) for v in values])
                )
            else:
                where_conditions.append(
                    f"JSON_VALUE({self._metadata_column}, '$.{safe_field}') = @{param_name}"
                )
                query_params.append(
                    bq.ScalarQueryParameter(param_name, "STRING", str(values))
                )

        where_clause = " AND ".join(where_conditions)
        delete_query = f"DELETE FROM {table_id} WHERE {where_clause}"
        job_config = bq.QueryJobConfig(query_parameters=query_params)

        try:
            query_job = await self._thread_func(
                self._connection.query, delete_query, job_config=job_config
            )
            await self._thread_func(query_job.result)
            deleted_count = query_job.num_dml_affected_rows or 0
            self.logger.debug(
                f"Deleted {deleted_count} documents from {table_id} with filter: {filter_dict}"
            )
            return deleted_count
        except Exception as e:
            raise RuntimeError(
                f"Failed to delete documents by filter: {e}"
            ) from e

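    # Filter semantics for delete_documents_by_filter() above: keys are combined
    # with AND, and list values expand to an IN UNNEST(...) match. For example
    # (hypothetical call):
    #
    #     await store.delete_documents_by_filter(
    #         {"source_type": ["pdf", "web_page"], "company": "acme"}
    #     )
    #
    # deletes rows whose source_type is 'pdf' or 'web_page' AND whose company
    # is 'acme'.
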
    async def delete_documents_by_ids(
        self,
        document_ids: List[str],
        table: Optional[str] = None,
        dataset: Optional[str] = None,
        id_column: Optional[str] = None,
        **kwargs
    ) -> int:
        """Delete documents by their IDs."""
        if not self._connected:
            await self.connection()

        if not document_ids:
            self.logger.warning("No document IDs provided for deletion")
            return 0

        table = table or self.table_name
        dataset = dataset or self.dataset
        id_column = id_column or self._id_column

        if not table:
            raise ValueError("Table name must be provided")

        table_id = f"{self._project_id}.{dataset}.{table}"

        # Build parameterized query for multiple IDs
        query_parameters = []
        value_params = []
        for i, doc_id in enumerate(document_ids):
            param_name = f"id_{i}"
            value_params.append(f"@{param_name}")
            query_parameters.append(
                bq.ScalarQueryParameter(param_name, "STRING", str(doc_id))
            )

        delete_query = f"""
        DELETE FROM `{table_id}`
        WHERE {id_column} IN ({', '.join(value_params)})
        """

        try:
            job_config = bq.QueryJobConfig(query_parameters=query_parameters)
            # Run the blocking client calls in a worker thread, matching
            # delete_documents() above.
            query_job = await self._thread_func(
                self._connection.query, delete_query, job_config=job_config
            )
            await self._thread_func(query_job.result)  # Wait for completion

            deleted_count = query_job.num_dml_affected_rows or 0

            self.logger.info(
                f"Deleted {deleted_count} documents by IDs from {table_id}"
            )

            return deleted_count

        except Exception as e:
            self.logger.error(f"Error deleting documents by IDs: {e}")
            raise RuntimeError(
                f"Failed to delete documents by IDs: {e}"
            ) from e

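    # Hypothetical usage of delete_documents_by_ids() above: the IDs are bound
    # as query parameters, so two IDs produce WHERE {id_column} IN (@id_0, @id_1).
    #
    #     deleted = await store.delete_documents_by_ids(["doc-001", "doc-002"])
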
    async def delete_all_documents(
        self,
        table: Optional[str] = None,
        dataset: Optional[str] = None,
        confirm: bool = False,
        **kwargs
    ) -> int:
        """Delete ALL documents from the BigQuery table."""
        if not confirm:
            raise ValueError(
                "This operation will delete ALL documents. "
                "Set confirm=True to proceed."
            )

        if not self._connected:
            await self.connection()

        table = table or self.table_name
        dataset = dataset or self.dataset

        if not table:
            raise ValueError("Table name must be provided")

        table_id = f"{self._project_id}.{dataset}.{table}"

        try:
            # First count existing documents (blocking client calls run in a
            # worker thread, as in the other query helpers above)
            count_query = f"SELECT COUNT(*) as total FROM `{table_id}`"
            count_job = await self._thread_func(self._connection.query, count_query)
            count_rows = await self._thread_func(count_job.result)
            total_docs = list(count_rows)[0].total

            if total_docs == 0:
                self.logger.info(f"No documents to delete from {table_id}")
                return 0

            # Delete all documents
            delete_query = f"DELETE FROM `{table_id}` WHERE TRUE"
            query_job = await self._thread_func(self._connection.query, delete_query)
            await self._thread_func(query_job.result)  # Wait for completion

            deleted_count = query_job.num_dml_affected_rows or 0

            self.logger.warning(
                f"DELETED ALL {deleted_count} documents from {table_id}"
            )

            return deleted_count

        except Exception as e:
            self.logger.error(f"Error deleting all documents: {e}")
            raise RuntimeError(f"Failed to delete all documents: {e}") from e

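    # The confirm guard above means delete_all_documents() raises ValueError
    # when called without confirm=True; a hypothetical call that actually
    # truncates the table (via DELETE ... WHERE TRUE) looks like:
    #
    #     deleted = await store.delete_all_documents(confirm=True)
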
    async def count_documents_by_filter(
        self,
        filter_dict: Dict[str, Union[str, List[str]]],
        table: Optional[str] = None,
        dataset: Optional[str] = None,
        metadata_column: Optional[str] = None,
        **kwargs
    ) -> int:
        """Count documents that would be affected by a filter."""
        if not self._connected:
            await self.connection()

        if not filter_dict:
            return 0

        table = table or self.table_name
        dataset = dataset or self.dataset
        metadata_column = metadata_column or self._metadata_column

        if not table:
            raise ValueError("Table name must be provided")

        table_id = f"{self._project_id}.{dataset}.{table}"

        # Build WHERE conditions (same logic as delete_documents_by_filter)
        where_conditions = []
        query_parameters = []

        for field, values in filter_dict.items():
            if isinstance(values, (list, tuple)):
                value_params = []
                for i, value in enumerate(values):
                    param_name = f"{field}_{i}"
                    value_params.append(f"@{param_name}")
                    query_parameters.append(
                        bq.ScalarQueryParameter(param_name, "STRING", str(value))
                    )

                condition = f"JSON_EXTRACT_SCALAR({metadata_column}, '$.{field}') IN ({', '.join(value_params)})"
                where_conditions.append(condition)
            else:
                param_name = f"{field}_single"
                where_conditions.append(
                    f"JSON_EXTRACT_SCALAR({metadata_column}, '$.{field}') = @{param_name}"
                )
                query_parameters.append(
                    bq.ScalarQueryParameter(param_name, "STRING", str(values))
                )

        where_clause = " AND ".join(where_conditions)
        count_query = f"""
        SELECT COUNT(*) as total FROM `{table_id}`
        WHERE {where_clause}
        """

        try:
            job_config = bq.QueryJobConfig(query_parameters=query_parameters)
            # Run the blocking client calls in a worker thread, matching the
            # other query helpers in this class.
            query_job = await self._thread_func(
                self._connection.query, count_query, job_config=job_config
            )
            rows = await self._thread_func(query_job.result)
            count = list(rows)[0].total

            self.logger.info(
                f"Found {count} documents matching filter: {filter_dict}"
            )

            return count

        except Exception as e:
            self.logger.error(f"Error counting documents: {e}")
            raise RuntimeError(f"Failed to count documents: {e}") from e

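    # count_documents_by_filter() builds the same kind of metadata conditions as
    # the delete helpers above, so it can serve as a dry run before deleting
    # (hypothetical sketch):
    #
    #     flt = {"source_type": "web_page"}
    #     n = await store.count_documents_by_filter(flt)
    #     if n:
    #         await store.delete_documents_by_filter(flt)
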
    async def delete_collection(
        self,
        table: str,
        dataset: Optional[str] = None
    ) -> None:
        """Delete a collection (table) from BigQuery."""
        if not self._connected:
            await self.connection()

        dataset = dataset or self.dataset
        table_id = f"{self._project_id}.{dataset}.{table}"

        if not await self.collection_exists(table, dataset):
            raise RuntimeError(f"Collection {table_id} does not exist")

        try:
            # Dropping the table is a blocking client call; run it in a worker thread
            await self._thread_func(self._connection.delete_table, table_id)
            self.logger.info(f"Collection {table_id} deleted successfully")

            # Remove from cache
            if table_id in self._collection_store_cache:
                del self._collection_store_cache[table_id]

        except Exception as e:
            self.logger.error(f"Error deleting collection: {e}")
            raise RuntimeError(f"Failed to delete collection: {e}") from e

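    # Hypothetical usage of delete_collection() above: drops the backing table
    # via the BigQuery client and evicts it from the local collection cache.
    #
    #     await store.delete_collection("my_documents")
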
    async def from_documents(
        self,
        documents: List[Document],
        table: Optional[str] = None,
        dataset: Optional[str] = None,
        embedding_column: str = 'embedding',
        content_column: str = 'document',
        metadata_column: str = 'metadata',
        chunk_size: int = 8192,
        chunk_overlap: int = 200,
        store_full_document: bool = True,
        **kwargs
    ) -> Dict[str, Any]:
        """Add documents using late chunking strategy (if available)."""
        if not self._connected:
            await self.connection()

        table = table or self.table_name
        dataset = dataset or self.dataset

        if not table:
            raise ValueError("Table name must be provided")

        # For BigQuery, we implement a simpler version without late chunking,
        # since LateChunkingProcessor might not be available
        await self.add_documents(
            documents=documents,
            table=table,
            dataset=dataset,
            embedding_column=embedding_column,
            content_column=content_column,
            metadata_column=metadata_column,
            **kwargs
        )

        stats = {
            'documents_processed': len(documents),
            'chunks_created': 0,  # Not implementing chunking in this version
            'full_documents_stored': len(documents)
        }

        return stats

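    # As noted above, from_documents() currently delegates to add_documents()
    # without late chunking, so for N input documents the returned stats dict
    # has the shape:
    #
    #     {'documents_processed': N, 'chunks_created': 0, 'full_documents_stored': N}
    #
    # Hypothetical call: stats = await store.from_documents(docs, table="docs")
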
    # Context manager support
    async def __aenter__(self):
        """Context manager entry."""
        if not self._connected:
            await self.connection()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        # BigQuery client doesn't need explicit cleanup
        pass

    async def disconnect(self) -> None:
        """Disconnect from BigQuery (cleanup resources)."""
        if self._connection:
            self._connection.close()
            self._connection = None
            self._connected = False
            self.logger.info("BigQuery client disconnected")

    def __str__(self) -> str:
        return f"BigQueryStore(project={self._project_id}, dataset={self.dataset})"

    def __repr__(self) -> str:
        return (
            f"<BigQueryStore(project='{self._project_id}', "
            f"dataset='{self.dataset}', table='{self.table_name}')>"
        )
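    # Because of __aenter__/__aexit__ above, the store can be used as an async
    # context manager (constructor arguments here are assumed, not taken from
    # this file):
    #
    #     async with BigQueryStore(project_id="my-project", dataset="rag") as store:
    #         n = await store.count_documents_by_filter({"source_type": "pdf"})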