ai_parrot-0.17.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- agentui/.prettierrc +15 -0
- agentui/QUICKSTART.md +272 -0
- agentui/README.md +59 -0
- agentui/env.example +16 -0
- agentui/jsconfig.json +14 -0
- agentui/package-lock.json +4242 -0
- agentui/package.json +34 -0
- agentui/scripts/postinstall/apply-patches.mjs +260 -0
- agentui/src/app.css +61 -0
- agentui/src/app.d.ts +13 -0
- agentui/src/app.html +12 -0
- agentui/src/components/LoadingSpinner.svelte +64 -0
- agentui/src/components/ThemeSwitcher.svelte +159 -0
- agentui/src/components/index.js +4 -0
- agentui/src/lib/api/bots.ts +60 -0
- agentui/src/lib/api/chat.ts +22 -0
- agentui/src/lib/api/http.ts +25 -0
- agentui/src/lib/components/BotCard.svelte +33 -0
- agentui/src/lib/components/ChatBubble.svelte +63 -0
- agentui/src/lib/components/Toast.svelte +21 -0
- agentui/src/lib/config.ts +20 -0
- agentui/src/lib/stores/auth.svelte.ts +73 -0
- agentui/src/lib/stores/theme.svelte.js +64 -0
- agentui/src/lib/stores/toast.svelte.ts +31 -0
- agentui/src/lib/utils/conversation.ts +39 -0
- agentui/src/routes/+layout.svelte +20 -0
- agentui/src/routes/+page.svelte +232 -0
- agentui/src/routes/login/+page.svelte +200 -0
- agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
- agentui/src/routes/talk/[agentId]/+page.ts +7 -0
- agentui/static/README.md +1 -0
- agentui/svelte.config.js +11 -0
- agentui/tailwind.config.ts +53 -0
- agentui/tsconfig.json +3 -0
- agentui/vite.config.ts +10 -0
- ai_parrot-0.17.2.dist-info/METADATA +472 -0
- ai_parrot-0.17.2.dist-info/RECORD +535 -0
- ai_parrot-0.17.2.dist-info/WHEEL +6 -0
- ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
- ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
- ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
- crew-builder/.prettierrc +15 -0
- crew-builder/QUICKSTART.md +259 -0
- crew-builder/README.md +113 -0
- crew-builder/env.example +17 -0
- crew-builder/jsconfig.json +14 -0
- crew-builder/package-lock.json +4182 -0
- crew-builder/package.json +37 -0
- crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
- crew-builder/src/app.css +62 -0
- crew-builder/src/app.d.ts +13 -0
- crew-builder/src/app.html +12 -0
- crew-builder/src/components/LoadingSpinner.svelte +64 -0
- crew-builder/src/components/ThemeSwitcher.svelte +149 -0
- crew-builder/src/components/index.js +9 -0
- crew-builder/src/lib/api/bots.ts +60 -0
- crew-builder/src/lib/api/chat.ts +80 -0
- crew-builder/src/lib/api/client.ts +56 -0
- crew-builder/src/lib/api/crew/crew.ts +136 -0
- crew-builder/src/lib/api/index.ts +5 -0
- crew-builder/src/lib/api/o365/auth.ts +65 -0
- crew-builder/src/lib/auth/auth.ts +54 -0
- crew-builder/src/lib/components/AgentNode.svelte +43 -0
- crew-builder/src/lib/components/BotCard.svelte +33 -0
- crew-builder/src/lib/components/ChatBubble.svelte +67 -0
- crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
- crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
- crew-builder/src/lib/components/JsonViewer.svelte +24 -0
- crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
- crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
- crew-builder/src/lib/components/Toast.svelte +67 -0
- crew-builder/src/lib/components/Toolbar.svelte +157 -0
- crew-builder/src/lib/components/index.ts +10 -0
- crew-builder/src/lib/config.ts +8 -0
- crew-builder/src/lib/stores/auth.svelte.ts +228 -0
- crew-builder/src/lib/stores/crewStore.ts +369 -0
- crew-builder/src/lib/stores/theme.svelte.js +145 -0
- crew-builder/src/lib/stores/toast.svelte.ts +69 -0
- crew-builder/src/lib/utils/conversation.ts +39 -0
- crew-builder/src/lib/utils/markdown.ts +122 -0
- crew-builder/src/lib/utils/talkHistory.ts +47 -0
- crew-builder/src/routes/+layout.svelte +20 -0
- crew-builder/src/routes/+page.svelte +539 -0
- crew-builder/src/routes/agents/+page.svelte +247 -0
- crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
- crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
- crew-builder/src/routes/builder/+page.svelte +204 -0
- crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
- crew-builder/src/routes/crew/ask/+page.ts +1 -0
- crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
- crew-builder/src/routes/login/+page.svelte +197 -0
- crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
- crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
- crew-builder/static/README.md +1 -0
- crew-builder/svelte.config.js +11 -0
- crew-builder/tailwind.config.ts +53 -0
- crew-builder/tsconfig.json +3 -0
- crew-builder/vite.config.ts +10 -0
- mcp_servers/calculator_server.py +309 -0
- parrot/__init__.py +27 -0
- parrot/__pycache__/__init__.cpython-310.pyc +0 -0
- parrot/__pycache__/version.cpython-310.pyc +0 -0
- parrot/_version.py +34 -0
- parrot/a2a/__init__.py +48 -0
- parrot/a2a/client.py +658 -0
- parrot/a2a/discovery.py +89 -0
- parrot/a2a/mixin.py +257 -0
- parrot/a2a/models.py +376 -0
- parrot/a2a/server.py +770 -0
- parrot/agents/__init__.py +29 -0
- parrot/bots/__init__.py +12 -0
- parrot/bots/a2a_agent.py +19 -0
- parrot/bots/abstract.py +3139 -0
- parrot/bots/agent.py +1129 -0
- parrot/bots/basic.py +9 -0
- parrot/bots/chatbot.py +669 -0
- parrot/bots/data.py +1618 -0
- parrot/bots/database/__init__.py +5 -0
- parrot/bots/database/abstract.py +3071 -0
- parrot/bots/database/cache.py +286 -0
- parrot/bots/database/models.py +468 -0
- parrot/bots/database/prompts.py +154 -0
- parrot/bots/database/retries.py +98 -0
- parrot/bots/database/router.py +269 -0
- parrot/bots/database/sql.py +41 -0
- parrot/bots/db/__init__.py +6 -0
- parrot/bots/db/abstract.py +556 -0
- parrot/bots/db/bigquery.py +602 -0
- parrot/bots/db/cache.py +85 -0
- parrot/bots/db/documentdb.py +668 -0
- parrot/bots/db/elastic.py +1014 -0
- parrot/bots/db/influx.py +898 -0
- parrot/bots/db/mock.py +96 -0
- parrot/bots/db/multi.py +783 -0
- parrot/bots/db/prompts.py +185 -0
- parrot/bots/db/sql.py +1255 -0
- parrot/bots/db/tools.py +212 -0
- parrot/bots/document.py +680 -0
- parrot/bots/hrbot.py +15 -0
- parrot/bots/kb.py +170 -0
- parrot/bots/mcp.py +36 -0
- parrot/bots/orchestration/README.md +463 -0
- parrot/bots/orchestration/__init__.py +1 -0
- parrot/bots/orchestration/agent.py +155 -0
- parrot/bots/orchestration/crew.py +3330 -0
- parrot/bots/orchestration/fsm.py +1179 -0
- parrot/bots/orchestration/hr.py +434 -0
- parrot/bots/orchestration/storage/__init__.py +4 -0
- parrot/bots/orchestration/storage/memory.py +100 -0
- parrot/bots/orchestration/storage/mixin.py +119 -0
- parrot/bots/orchestration/verify.py +202 -0
- parrot/bots/product.py +204 -0
- parrot/bots/prompts/__init__.py +96 -0
- parrot/bots/prompts/agents.py +155 -0
- parrot/bots/prompts/data.py +216 -0
- parrot/bots/prompts/output_generation.py +8 -0
- parrot/bots/scraper/__init__.py +3 -0
- parrot/bots/scraper/models.py +122 -0
- parrot/bots/scraper/scraper.py +1173 -0
- parrot/bots/scraper/templates.py +115 -0
- parrot/bots/stores/__init__.py +5 -0
- parrot/bots/stores/local.py +172 -0
- parrot/bots/webdev.py +81 -0
- parrot/cli.py +17 -0
- parrot/clients/__init__.py +16 -0
- parrot/clients/base.py +1491 -0
- parrot/clients/claude.py +1191 -0
- parrot/clients/factory.py +129 -0
- parrot/clients/google.py +4567 -0
- parrot/clients/gpt.py +1975 -0
- parrot/clients/grok.py +432 -0
- parrot/clients/groq.py +986 -0
- parrot/clients/hf.py +582 -0
- parrot/clients/models.py +18 -0
- parrot/conf.py +395 -0
- parrot/embeddings/__init__.py +9 -0
- parrot/embeddings/base.py +157 -0
- parrot/embeddings/google.py +98 -0
- parrot/embeddings/huggingface.py +74 -0
- parrot/embeddings/openai.py +84 -0
- parrot/embeddings/processor.py +88 -0
- parrot/exceptions.c +13868 -0
- parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/exceptions.pxd +22 -0
- parrot/exceptions.pxi +15 -0
- parrot/exceptions.pyx +44 -0
- parrot/generators/__init__.py +29 -0
- parrot/generators/base.py +200 -0
- parrot/generators/html.py +293 -0
- parrot/generators/react.py +205 -0
- parrot/generators/streamlit.py +203 -0
- parrot/generators/template.py +105 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/agent.py +861 -0
- parrot/handlers/agents/__init__.py +1 -0
- parrot/handlers/agents/abstract.py +900 -0
- parrot/handlers/bots.py +338 -0
- parrot/handlers/chat.py +915 -0
- parrot/handlers/creation.sql +192 -0
- parrot/handlers/crew/ARCHITECTURE.md +362 -0
- parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
- parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
- parrot/handlers/crew/__init__.py +0 -0
- parrot/handlers/crew/handler.py +801 -0
- parrot/handlers/crew/models.py +229 -0
- parrot/handlers/crew/redis_persistence.py +523 -0
- parrot/handlers/jobs/__init__.py +10 -0
- parrot/handlers/jobs/job.py +384 -0
- parrot/handlers/jobs/mixin.py +627 -0
- parrot/handlers/jobs/models.py +115 -0
- parrot/handlers/jobs/worker.py +31 -0
- parrot/handlers/models.py +596 -0
- parrot/handlers/o365_auth.py +105 -0
- parrot/handlers/stream.py +337 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/aws.py +143 -0
- parrot/interfaces/credentials.py +113 -0
- parrot/interfaces/database.py +27 -0
- parrot/interfaces/google.py +1123 -0
- parrot/interfaces/hierarchy.py +1227 -0
- parrot/interfaces/http.py +651 -0
- parrot/interfaces/images/__init__.py +0 -0
- parrot/interfaces/images/plugins/__init__.py +24 -0
- parrot/interfaces/images/plugins/abstract.py +58 -0
- parrot/interfaces/images/plugins/analisys.py +148 -0
- parrot/interfaces/images/plugins/classify.py +150 -0
- parrot/interfaces/images/plugins/classifybase.py +182 -0
- parrot/interfaces/images/plugins/detect.py +150 -0
- parrot/interfaces/images/plugins/exif.py +1103 -0
- parrot/interfaces/images/plugins/hash.py +52 -0
- parrot/interfaces/images/plugins/vision.py +104 -0
- parrot/interfaces/images/plugins/yolo.py +66 -0
- parrot/interfaces/images/plugins/zerodetect.py +197 -0
- parrot/interfaces/o365.py +978 -0
- parrot/interfaces/onedrive.py +822 -0
- parrot/interfaces/sharepoint.py +1435 -0
- parrot/interfaces/soap.py +257 -0
- parrot/loaders/__init__.py +8 -0
- parrot/loaders/abstract.py +1131 -0
- parrot/loaders/audio.py +199 -0
- parrot/loaders/basepdf.py +53 -0
- parrot/loaders/basevideo.py +1568 -0
- parrot/loaders/csv.py +409 -0
- parrot/loaders/docx.py +116 -0
- parrot/loaders/epubloader.py +316 -0
- parrot/loaders/excel.py +199 -0
- parrot/loaders/factory.py +55 -0
- parrot/loaders/files/__init__.py +0 -0
- parrot/loaders/files/abstract.py +39 -0
- parrot/loaders/files/html.py +26 -0
- parrot/loaders/files/text.py +63 -0
- parrot/loaders/html.py +152 -0
- parrot/loaders/markdown.py +442 -0
- parrot/loaders/pdf.py +373 -0
- parrot/loaders/pdfmark.py +320 -0
- parrot/loaders/pdftables.py +506 -0
- parrot/loaders/ppt.py +476 -0
- parrot/loaders/qa.py +63 -0
- parrot/loaders/splitters/__init__.py +10 -0
- parrot/loaders/splitters/base.py +138 -0
- parrot/loaders/splitters/md.py +228 -0
- parrot/loaders/splitters/token.py +143 -0
- parrot/loaders/txt.py +26 -0
- parrot/loaders/video.py +89 -0
- parrot/loaders/videolocal.py +218 -0
- parrot/loaders/videounderstanding.py +377 -0
- parrot/loaders/vimeo.py +167 -0
- parrot/loaders/web.py +599 -0
- parrot/loaders/youtube.py +504 -0
- parrot/manager/__init__.py +5 -0
- parrot/manager/manager.py +1030 -0
- parrot/mcp/__init__.py +28 -0
- parrot/mcp/adapter.py +105 -0
- parrot/mcp/cli.py +174 -0
- parrot/mcp/client.py +119 -0
- parrot/mcp/config.py +75 -0
- parrot/mcp/integration.py +842 -0
- parrot/mcp/oauth.py +933 -0
- parrot/mcp/server.py +225 -0
- parrot/mcp/transports/__init__.py +3 -0
- parrot/mcp/transports/base.py +279 -0
- parrot/mcp/transports/grpc_session.py +163 -0
- parrot/mcp/transports/http.py +312 -0
- parrot/mcp/transports/mcp.proto +108 -0
- parrot/mcp/transports/quic.py +1082 -0
- parrot/mcp/transports/sse.py +330 -0
- parrot/mcp/transports/stdio.py +309 -0
- parrot/mcp/transports/unix.py +395 -0
- parrot/mcp/transports/websocket.py +547 -0
- parrot/memory/__init__.py +16 -0
- parrot/memory/abstract.py +209 -0
- parrot/memory/agent.py +32 -0
- parrot/memory/cache.py +175 -0
- parrot/memory/core.py +555 -0
- parrot/memory/file.py +153 -0
- parrot/memory/mem.py +131 -0
- parrot/memory/redis.py +613 -0
- parrot/models/__init__.py +46 -0
- parrot/models/basic.py +118 -0
- parrot/models/compliance.py +208 -0
- parrot/models/crew.py +395 -0
- parrot/models/detections.py +654 -0
- parrot/models/generation.py +85 -0
- parrot/models/google.py +223 -0
- parrot/models/groq.py +23 -0
- parrot/models/openai.py +30 -0
- parrot/models/outputs.py +285 -0
- parrot/models/responses.py +938 -0
- parrot/notifications/__init__.py +743 -0
- parrot/openapi/__init__.py +3 -0
- parrot/openapi/components.yaml +641 -0
- parrot/openapi/config.py +322 -0
- parrot/outputs/__init__.py +32 -0
- parrot/outputs/formats/__init__.py +108 -0
- parrot/outputs/formats/altair.py +359 -0
- parrot/outputs/formats/application.py +122 -0
- parrot/outputs/formats/base.py +351 -0
- parrot/outputs/formats/bokeh.py +356 -0
- parrot/outputs/formats/card.py +424 -0
- parrot/outputs/formats/chart.py +436 -0
- parrot/outputs/formats/d3.py +255 -0
- parrot/outputs/formats/echarts.py +310 -0
- parrot/outputs/formats/generators/__init__.py +0 -0
- parrot/outputs/formats/generators/abstract.py +61 -0
- parrot/outputs/formats/generators/panel.py +145 -0
- parrot/outputs/formats/generators/streamlit.py +86 -0
- parrot/outputs/formats/generators/terminal.py +63 -0
- parrot/outputs/formats/holoviews.py +310 -0
- parrot/outputs/formats/html.py +147 -0
- parrot/outputs/formats/jinja2.py +46 -0
- parrot/outputs/formats/json.py +87 -0
- parrot/outputs/formats/map.py +933 -0
- parrot/outputs/formats/markdown.py +172 -0
- parrot/outputs/formats/matplotlib.py +237 -0
- parrot/outputs/formats/mixins/__init__.py +0 -0
- parrot/outputs/formats/mixins/emaps.py +855 -0
- parrot/outputs/formats/plotly.py +341 -0
- parrot/outputs/formats/seaborn.py +310 -0
- parrot/outputs/formats/table.py +397 -0
- parrot/outputs/formats/template_report.py +138 -0
- parrot/outputs/formats/yaml.py +125 -0
- parrot/outputs/formatter.py +152 -0
- parrot/outputs/templates/__init__.py +95 -0
- parrot/pipelines/__init__.py +0 -0
- parrot/pipelines/abstract.py +210 -0
- parrot/pipelines/detector.py +124 -0
- parrot/pipelines/models.py +90 -0
- parrot/pipelines/planogram.py +3002 -0
- parrot/pipelines/table.sql +97 -0
- parrot/plugins/__init__.py +106 -0
- parrot/plugins/importer.py +80 -0
- parrot/py.typed +0 -0
- parrot/registry/__init__.py +18 -0
- parrot/registry/registry.py +594 -0
- parrot/scheduler/__init__.py +1189 -0
- parrot/scheduler/models.py +60 -0
- parrot/security/__init__.py +16 -0
- parrot/security/prompt_injection.py +268 -0
- parrot/security/security_events.sql +25 -0
- parrot/services/__init__.py +1 -0
- parrot/services/mcp/__init__.py +8 -0
- parrot/services/mcp/config.py +13 -0
- parrot/services/mcp/server.py +295 -0
- parrot/services/o365_remote_auth.py +235 -0
- parrot/stores/__init__.py +7 -0
- parrot/stores/abstract.py +352 -0
- parrot/stores/arango.py +1090 -0
- parrot/stores/bigquery.py +1377 -0
- parrot/stores/cache.py +106 -0
- parrot/stores/empty.py +10 -0
- parrot/stores/faiss_store.py +1157 -0
- parrot/stores/kb/__init__.py +9 -0
- parrot/stores/kb/abstract.py +68 -0
- parrot/stores/kb/cache.py +165 -0
- parrot/stores/kb/doc.py +325 -0
- parrot/stores/kb/hierarchy.py +346 -0
- parrot/stores/kb/local.py +457 -0
- parrot/stores/kb/prompt.py +28 -0
- parrot/stores/kb/redis.py +659 -0
- parrot/stores/kb/store.py +115 -0
- parrot/stores/kb/user.py +374 -0
- parrot/stores/models.py +59 -0
- parrot/stores/pgvector.py +3 -0
- parrot/stores/postgres.py +2853 -0
- parrot/stores/utils/__init__.py +0 -0
- parrot/stores/utils/chunking.py +197 -0
- parrot/telemetry/__init__.py +3 -0
- parrot/telemetry/mixin.py +111 -0
- parrot/template/__init__.py +3 -0
- parrot/template/engine.py +259 -0
- parrot/tools/__init__.py +23 -0
- parrot/tools/abstract.py +644 -0
- parrot/tools/agent.py +363 -0
- parrot/tools/arangodbsearch.py +537 -0
- parrot/tools/arxiv_tool.py +188 -0
- parrot/tools/calculator/__init__.py +3 -0
- parrot/tools/calculator/operations/__init__.py +38 -0
- parrot/tools/calculator/operations/calculus.py +80 -0
- parrot/tools/calculator/operations/statistics.py +76 -0
- parrot/tools/calculator/tool.py +150 -0
- parrot/tools/cloudwatch.py +988 -0
- parrot/tools/codeinterpreter/__init__.py +127 -0
- parrot/tools/codeinterpreter/executor.py +371 -0
- parrot/tools/codeinterpreter/internals.py +473 -0
- parrot/tools/codeinterpreter/models.py +643 -0
- parrot/tools/codeinterpreter/prompts.py +224 -0
- parrot/tools/codeinterpreter/tool.py +664 -0
- parrot/tools/company_info/__init__.py +6 -0
- parrot/tools/company_info/tool.py +1138 -0
- parrot/tools/correlationanalysis.py +437 -0
- parrot/tools/database/abstract.py +286 -0
- parrot/tools/database/bq.py +115 -0
- parrot/tools/database/cache.py +284 -0
- parrot/tools/database/models.py +95 -0
- parrot/tools/database/pg.py +343 -0
- parrot/tools/databasequery.py +1159 -0
- parrot/tools/db.py +1800 -0
- parrot/tools/ddgo.py +370 -0
- parrot/tools/decorators.py +271 -0
- parrot/tools/dftohtml.py +282 -0
- parrot/tools/document.py +549 -0
- parrot/tools/ecs.py +819 -0
- parrot/tools/edareport.py +368 -0
- parrot/tools/elasticsearch.py +1049 -0
- parrot/tools/employees.py +462 -0
- parrot/tools/epson/__init__.py +96 -0
- parrot/tools/excel.py +683 -0
- parrot/tools/file/__init__.py +13 -0
- parrot/tools/file/abstract.py +76 -0
- parrot/tools/file/gcs.py +378 -0
- parrot/tools/file/local.py +284 -0
- parrot/tools/file/s3.py +511 -0
- parrot/tools/file/tmp.py +309 -0
- parrot/tools/file/tool.py +501 -0
- parrot/tools/file_reader.py +129 -0
- parrot/tools/flowtask/__init__.py +19 -0
- parrot/tools/flowtask/tool.py +761 -0
- parrot/tools/gittoolkit.py +508 -0
- parrot/tools/google/__init__.py +18 -0
- parrot/tools/google/base.py +169 -0
- parrot/tools/google/tools.py +1251 -0
- parrot/tools/googlelocation.py +5 -0
- parrot/tools/googleroutes.py +5 -0
- parrot/tools/googlesearch.py +5 -0
- parrot/tools/googlesitesearch.py +5 -0
- parrot/tools/googlevoice.py +2 -0
- parrot/tools/gvoice.py +695 -0
- parrot/tools/ibisworld/README.md +225 -0
- parrot/tools/ibisworld/__init__.py +11 -0
- parrot/tools/ibisworld/tool.py +366 -0
- parrot/tools/jiratoolkit.py +1718 -0
- parrot/tools/manager.py +1098 -0
- parrot/tools/math.py +152 -0
- parrot/tools/metadata.py +476 -0
- parrot/tools/msteams.py +1621 -0
- parrot/tools/msword.py +635 -0
- parrot/tools/multidb.py +580 -0
- parrot/tools/multistoresearch.py +369 -0
- parrot/tools/networkninja.py +167 -0
- parrot/tools/nextstop/__init__.py +4 -0
- parrot/tools/nextstop/base.py +286 -0
- parrot/tools/nextstop/employee.py +733 -0
- parrot/tools/nextstop/store.py +462 -0
- parrot/tools/notification.py +435 -0
- parrot/tools/o365/__init__.py +42 -0
- parrot/tools/o365/base.py +295 -0
- parrot/tools/o365/bundle.py +522 -0
- parrot/tools/o365/events.py +554 -0
- parrot/tools/o365/mail.py +992 -0
- parrot/tools/o365/onedrive.py +497 -0
- parrot/tools/o365/sharepoint.py +641 -0
- parrot/tools/openapi_toolkit.py +904 -0
- parrot/tools/openweather.py +527 -0
- parrot/tools/pdfprint.py +1001 -0
- parrot/tools/powerbi.py +518 -0
- parrot/tools/powerpoint.py +1113 -0
- parrot/tools/pricestool.py +146 -0
- parrot/tools/products/__init__.py +246 -0
- parrot/tools/prophet_tool.py +171 -0
- parrot/tools/pythonpandas.py +630 -0
- parrot/tools/pythonrepl.py +910 -0
- parrot/tools/qsource.py +436 -0
- parrot/tools/querytoolkit.py +395 -0
- parrot/tools/quickeda.py +827 -0
- parrot/tools/resttool.py +553 -0
- parrot/tools/retail/__init__.py +0 -0
- parrot/tools/retail/bby.py +528 -0
- parrot/tools/sandboxtool.py +703 -0
- parrot/tools/sassie/__init__.py +352 -0
- parrot/tools/scraping/__init__.py +7 -0
- parrot/tools/scraping/docs/select.md +466 -0
- parrot/tools/scraping/documentation.md +1278 -0
- parrot/tools/scraping/driver.py +436 -0
- parrot/tools/scraping/models.py +576 -0
- parrot/tools/scraping/options.py +85 -0
- parrot/tools/scraping/orchestrator.py +517 -0
- parrot/tools/scraping/readme.md +740 -0
- parrot/tools/scraping/tool.py +3115 -0
- parrot/tools/seasonaldetection.py +642 -0
- parrot/tools/shell_tool/__init__.py +5 -0
- parrot/tools/shell_tool/actions.py +408 -0
- parrot/tools/shell_tool/engine.py +155 -0
- parrot/tools/shell_tool/models.py +322 -0
- parrot/tools/shell_tool/tool.py +442 -0
- parrot/tools/site_search.py +214 -0
- parrot/tools/textfile.py +418 -0
- parrot/tools/think.py +378 -0
- parrot/tools/toolkit.py +298 -0
- parrot/tools/webapp_tool.py +187 -0
- parrot/tools/whatif.py +1279 -0
- parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
- parrot/tools/workday/__init__.py +6 -0
- parrot/tools/workday/models.py +1389 -0
- parrot/tools/workday/tool.py +1293 -0
- parrot/tools/yfinance_tool.py +306 -0
- parrot/tools/zipcode.py +217 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/helpers.py +73 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.c +12078 -0
- parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/parsers/toml.pyx +21 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpp +20936 -0
- parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/types.pyx +213 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- parrot/yaml-rs/Cargo.lock +350 -0
- parrot/yaml-rs/Cargo.toml +19 -0
- parrot/yaml-rs/pyproject.toml +19 -0
- parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
- parrot/yaml-rs/src/lib.rs +222 -0
- requirements/docker-compose.yml +24 -0
- requirements/requirements-dev.txt +21 -0
parrot/tools/db.py
ADDED
@@ -0,0 +1,1800 @@

```python
"""
Unified Database Tool for AI-Parrot

Consolidates schema extraction, knowledge base building, query generation,
validation, and execution into a single, powerful database interface.
"""
from __future__ import annotations
from typing import Dict, List, Optional, Any, Union, Literal, Tuple
import re
import asyncio
import json
import hashlib
from datetime import datetime, timedelta, timezone
from enum import Enum
import pandas as pd
from pydantic import (
    BaseModel,
    Field,
    field_validator,
    model_validator
)
from asyncdb import AsyncDB
from .abstract import (
    AbstractTool,
    ToolResult,
    AbstractToolArgsSchema
)
from ..stores.abstract import AbstractStore
from ..clients.base import AbstractClient
from ..clients.factory import LLMFactory
from ..models import AIMessage


class DatabaseFlavor(str, Enum):
    """Supported database flavors."""
    POSTGRESQL = "postgresql"
    MYSQL = "mysql"
    SQLSERVER = "sqlserver"
    BIGQUERY = "bigquery"
    INFLUXDB = "influxdb"
    CASSANDRA = "cassandra"
    MONGODB = "mongodb"
    ELASTICSEARCH = "elasticsearch"
    SQLITE = "sqlite"
    DUCKDB = "duckdb"


class QueryType(str, Enum):
    """Supported query types."""
    SELECT = "SELECT"
    INSERT = "INSERT"
    UPDATE = "UPDATE"
    DELETE = "DELETE"
    CREATE = "CREATE"
    ALTER = "ALTER"
    DROP = "DROP"


class OutputFormat(str, Enum):
    """Supported output formats."""
    PANDAS = "pandas"
    JSON = "json"
    DICT = "dict"
    CSV = "csv"
    STRUCTURED = "structured"  # Uses Pydantic models


class SchemaMetadata(BaseModel):
    """Metadata for a database schema."""
    schema_name: str
    tables: List[Dict[str, Any]]
    views: List[Dict[str, Any]]
    functions: List[Dict[str, Any]]
    procedures: List[Dict[str, Any]]
    indexes: List[Dict[str, Any]]
    constraints: List[Dict[str, Any]]
    last_updated: datetime
    database_flavor: DatabaseFlavor


class QueryValidationResult(BaseModel):
    """Result of query validation."""
    is_valid: bool
    query_type: Optional[QueryType]
    affected_tables: List[str]
    estimated_cost: Optional[float]
    warnings: List[str]
    errors: List[str]
    security_checks: Dict[str, bool]


class DatabaseToolArgs(AbstractToolArgsSchema):
    """Arguments for the unified database tool."""

    # Query specification
    natural_language_query: Optional[str] = Field(
        None, description="Natural language description of what you want to query"
    )
    sql_query: Optional[str] = Field(
        None, description="Direct SQL query to execute"
    )

    # Database connection
    database_flavor: DatabaseFlavor = Field(
        DatabaseFlavor.POSTGRESQL, description="Type of database to connect to"
    )
    connection_params: Optional[Dict[str, Any]] = Field(
        None, description="Database connection parameters"
    )
    schema_names: List[str] = Field(
        default=["public"], description="Schema names to work with"
    )

    # Operation modes
    operation: Literal[
        "schema_extract", "query_generate", "query_validate",
        "query_execute", "full_pipeline", "explain_query"
    ] = Field(
        "full_pipeline", description="What operation to perform"
    )

    # Query options
    max_rows: int = Field(1000, description="Maximum rows to return")
    timeout_seconds: int = Field(300, description="Query timeout")
    dry_run: bool = Field(False, description="Validate without executing")

    # Output options
    output_format: OutputFormat = Field(
        OutputFormat.PANDAS, description="Format for query results"
    )
    structured_output_schema: Optional[Dict[str, Any]] = Field(
        None, description="Pydantic schema for structured outputs"
    )

    # Knowledge base options
    update_knowledge_base: bool = Field(
        True, description="Whether to update schema knowledge base"
    )
    cache_duration_hours: int = Field(
        24, description="How long to cache schema metadata"
    )

    @model_validator(mode='after')
    def validate_query_input(self) -> 'DatabaseToolArgs':
        # Ensure at least one query input is provided for query operations
        if self.operation in ['query_generate', 'query_execute', 'full_pipeline', 'explain_query']:
            if not self.natural_language_query and not self.sql_query:
                raise ValueError("Either natural_language_query or sql_query must be provided")
        return self
```
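Aside: a minimal sketch of how the `validate_query_input` validator behaves, assuming `AbstractToolArgsSchema` is a regular Pydantic model (its definition is not part of this file):

```python
from pydantic import ValidationError

# Query-style operations demand at least one query input.
try:
    DatabaseToolArgs(operation="full_pipeline")
except ValidationError as err:
    print(err)  # mentions: Either natural_language_query or sql_query must be provided

# schema_extract carries no such requirement, so bare defaults pass.
args = DatabaseToolArgs(operation="schema_extract", schema_names=["public", "sales"])
print(args.database_flavor)  # DatabaseFlavor.POSTGRESQL (the default)
```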
```python
class DatabaseTool(AbstractTool):
    """
    Unified Database Tool that handles the complete database interaction pipeline:

    1. Schema Discovery: Extract and cache table schemas from any supported database
    2. Knowledge Base Building: Store schema metadata in vector store for RAG
    3. Query Generation: Convert natural language to database-specific queries
    4. Query Validation: Syntax checking, security validation, cost estimation
    5. Query Execution: Safe execution with proper error handling
    6. Structured Output: Format results according to specified schemas

    This tool consolidates the functionality of SchemaTool, DatabaseQueryTool,
    and SQLAgent into a single, cohesive interface.
    """

    name = "database_tool"
    description = """Unified database tool for schema discovery, query generation,
    validation, and execution across multiple database types"""
    args_schema = DatabaseToolArgs

    def __init__(
        self,
        knowledge_store: Optional[AbstractStore] = None,
        default_connection_params: Optional[Dict[DatabaseFlavor, Dict]] = None,
        enable_query_caching: bool = True,
        llm: Optional[Union[AbstractClient, str]] = None,
        **kwargs
    ):
        """
        Initialize the unified database tool.

        Args:
            knowledge_store: Vector store for schema metadata and RAG
            default_connection_params: Default connection parameters per database type
            enable_query_caching: Whether to cache query results
            llm: LLM to use for query generation and validation
        """
        super().__init__(**kwargs)

        self.knowledge_store = knowledge_store
        self.default_connection_params = default_connection_params or {}
        self.enable_query_caching = enable_query_caching

        # Initialize LLM
        if isinstance(llm, str):
            self.llm = LLMFactory.create(llm)
        else:
            self.llm = llm

        # Cache for schema metadata and database connections
        self._schema_cache: Dict[str, Tuple[SchemaMetadata, datetime]] = {}
        self._connection_cache: Dict[str, AsyncDB] = {}

        # Database-specific query generators and validators
        self._query_generators = {}
        self._query_validators = {}

        self._setup_database_handlers()

    def _setup_database_handlers(self):
        """Initialize database-specific handlers for different flavors."""
        # This would be expanded to include handlers for each database type
        self._query_generators = {
            DatabaseFlavor.POSTGRESQL: self._generate_postgresql_query,
            DatabaseFlavor.MYSQL: self._generate_mysql_query,
            DatabaseFlavor.BIGQUERY: self._generate_bigquery_query,
            # Add more database-specific generators...
        }

        self._query_validators = {
            DatabaseFlavor.POSTGRESQL: self._validate_postgresql_query,
            DatabaseFlavor.MYSQL: self._validate_mysql_query,
            DatabaseFlavor.BIGQUERY: self._validate_bigquery_query,
            # Add more database-specific validators...
        }

    def _clean_sql(self, sql_query: str) -> str:
        """Clean SQL query from markdown formatting."""
        if not sql_query:
            return ""
        # Remove markdown code blocks
        clean_query = re.sub(r'```\w*\n?', '', sql_query)
        clean_query = clean_query.replace('```', '')
        return clean_query.strip()

    async def _execute(
        self,
        natural_language_query: Optional[str] = None,
        sql_query: Optional[str] = None,
        database_flavor: DatabaseFlavor = DatabaseFlavor.POSTGRESQL,
        connection_params: Optional[Dict[str, Any]] = None,
        schema_names: List[str] = ["public"],
        operation: str = "full_pipeline",
        max_rows: int = 1000,
        timeout_seconds: int = 300,
        dry_run: bool = False,
        output_format: OutputFormat = OutputFormat.PANDAS,
        structured_output_schema: Optional[Dict[str, Any]] = None,
        update_knowledge_base: bool = True,
        cache_duration_hours: int = 24,
        **kwargs
    ) -> ToolResult:
        """
        Execute the unified database tool pipeline.

        The method routes to different sub-operations based on the operation parameter,
        or executes the full pipeline for complete query processing.
        """
        try:
            # Fall back to default connection parameters if none were provided
            if connection_params is None:
                connection_params = self.default_connection_params.get(database_flavor)

            if sql_query:
                sql_query = self._clean_sql(sql_query)

            # Route to specific operations
            if operation == "schema_extract":
                return await self._extract_schema_operation(
                    database_flavor, connection_params, schema_names,
                    update_knowledge_base, cache_duration_hours
                )
            if operation == "query_generate":
                return await self._query_generation_operation(
                    natural_language_query, database_flavor, connection_params, schema_names
                )
            if operation == "query_validate":
                return await self._query_validation_operation(
                    sql_query or natural_language_query, database_flavor, connection_params
                )
            if operation == "query_execute":
                return await self._query_execution_operation(
                    sql_query, database_flavor, connection_params,
                    max_rows, timeout_seconds, output_format, structured_output_schema
                )
            if operation == "full_pipeline":
                return await self._full_pipeline_operation(
                    natural_language_query, sql_query, database_flavor, connection_params,
                    schema_names, max_rows, timeout_seconds, dry_run,
                    output_format, structured_output_schema, update_knowledge_base, cache_duration_hours
                )
            if operation == "explain_query":
                return await self._explain_query_operation(
                    sql_query or natural_language_query, database_flavor, connection_params
                )
            else:
                raise ValueError(f"Unknown operation: {operation}")

        except Exception as e:
            return ToolResult(
                status="error",
                result=None,
                error=f"Database tool execution failed: {str(e)}",
                metadata={
                    "operation": operation,
                    "database_flavor": database_flavor.value,
                    "timestamp": datetime.now(timezone.utc).isoformat()
                }
            )
```
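Aside: `_clean_sql` exists because LLMs tend to wrap generated SQL in markdown fences. A standalone sketch of the same two-step cleanup it performs:

```python
import re

raw = "```sql\nSELECT order_id, total FROM sales.orders\n```"
# The regex strips every fence marker (with its optional language tag and newline);
# replace() catches any stray backticks; strip() trims leftover whitespace.
clean = re.sub(r'```\w*\n?', '', raw).replace('```', '').strip()
print(clean)  # SELECT order_id, total FROM sales.orders
```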
```python
    async def _full_pipeline_operation(
        self,
        natural_language_query: Optional[str],
        sql_query: Optional[str],
        database_flavor: DatabaseFlavor,
        connection_params: Optional[Dict[str, Any]],
        schema_names: List[str],
        max_rows: int,
        timeout_seconds: int,
        dry_run: bool,
        output_format: OutputFormat,
        structured_output_schema: Optional[Dict[str, Any]],
        update_knowledge_base: bool,
        cache_duration_hours: int
    ) -> ToolResult:
        """
        Execute the complete database interaction pipeline.

        This is the main orchestrator method that combines all functionality:
        schema extraction, knowledge base updates, query generation, validation, and execution.
        """
        pipeline_results = {
            "schema_extraction": None,
            "query_generation": None,
            "query_validation": None,
            "query_execution": None,
            "knowledge_base_update": None
        }

        try:
            # Step 1: Extract and cache schema metadata
            self.logger.info(f"Step 1: Extracting schema for {database_flavor.value}")
            schema_result = await self._extract_schema_operation(
                database_flavor, connection_params, schema_names,
                update_knowledge_base, cache_duration_hours
            )
            pipeline_results["schema_extraction"] = schema_result.result

            # Step 2: Generate SQL query if natural language was provided
            generated_query = sql_query
            if natural_language_query:
                self.logger.info("Step 2: Generating SQL from natural language")
                query_result = await self._query_generation_operation(
                    natural_language_query, database_flavor, connection_params, schema_names
                )
                pipeline_results["query_generation"] = query_result.result
                generated_query = query_result.result.get("sql_query")

            if not generated_query:
                raise ValueError("No valid SQL query to execute")

            # Step 3: Validate the query
            self.logger.info("Step 3: Validating SQL query")
            validation_result = await self._query_validation_operation(
                generated_query, database_flavor, connection_params
            )
            pipeline_results["query_validation"] = validation_result.result

            if not validation_result.result["is_valid"]:
                if dry_run:
                    return ToolResult(
                        status="success",
                        result={
                            "pipeline_results": pipeline_results,
                            "dry_run": True,
                            "query_valid": False
                        },
                        metadata={"operation": "full_pipeline", "dry_run": True}
                    )
                else:
                    raise ValueError(f"Query validation failed: {validation_result.result['errors']}")

            # Step 4: Execute the query (unless dry run)
            if not dry_run:
                self.logger.info("Step 4: Executing validated query")
                execution_result = await self._query_execution_operation(
                    generated_query, database_flavor, connection_params,
                    max_rows, timeout_seconds, output_format, structured_output_schema
                )
                pipeline_results["query_execution"] = execution_result.result

            # Success! Return comprehensive results
            return ToolResult(
                status="success",
                result={
                    "pipeline_results": pipeline_results,
                    "final_query": generated_query,
                    "dry_run": dry_run,
                    "execution_summary": {
                        "rows_returned": len(pipeline_results["query_execution"]["data"]) if not dry_run and pipeline_results["query_execution"] else 0,
                        "execution_time_seconds": pipeline_results["query_execution"]["execution_time"] if not dry_run and pipeline_results["query_execution"] else None,
                        "output_format": output_format.value
                    }
                },
                metadata={
                    "operation": "full_pipeline",
                    "database_flavor": database_flavor.value,
                    "schema_count": len(schema_names),
                    "natural_language_input": natural_language_query is not None,
                    "timestamp": datetime.utcnow().isoformat()
                }
            )

        except Exception as e:
            return ToolResult(
                status="error",
                result={"pipeline_results": pipeline_results},
                error=f"Pipeline failed at step: {str(e)}",
                metadata={"operation": "full_pipeline", "partial_results": True}
            )
```
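Aside: a hedged sketch of driving the pipeline in dry-run mode. `_execute` is the internal entry point reconstructed above (a deployed agent would normally reach it through the `AbstractTool` run interface), the `"openai"` model name and the connection parameters are hypothetical, and it assumes the flavor-specific generator/validator methods referenced in `_setup_database_handlers` are defined further down the file:

```python
import asyncio

async def main():
    tool = DatabaseTool(llm="openai")  # assumes LLMFactory.create() accepts this name
    result = await tool._execute(
        natural_language_query="Total revenue per store over the last 30 days",
        database_flavor=DatabaseFlavor.POSTGRESQL,
        connection_params={  # hypothetical credentials
            "host": "localhost", "port": 5432,
            "user": "parrot", "password": "...", "database": "retail",
        },
        operation="full_pipeline",
        dry_run=True,  # run schema extraction, generation, validation; skip execution
    )
    print(result.status)
    if result.result:
        print(result.result.get("final_query"))

asyncio.run(main())
```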
```python
    async def _extract_schema_operation(
        self,
        database_flavor: DatabaseFlavor,
        connection_params: Optional[Dict[str, Any]],
        schema_names: List[str],
        update_knowledge_base: bool,
        cache_duration_hours: int
    ) -> ToolResult:
        """Extract database schema metadata and optionally update knowledge base."""
        try:
            # Check cache first
            cache_key = self._generate_schema_cache_key(database_flavor, connection_params, schema_names)
            cached_schema, cache_time = self._schema_cache.get(cache_key, (None, None))

            if cached_schema and cache_time:
                cache_age = datetime.utcnow() - cache_time
                if cache_age < timedelta(hours=cache_duration_hours):
                    self.logger.info(f"Using cached schema metadata (age: {cache_age})")
                    return ToolResult(
                        status="success",
                        result=cached_schema.dict(),
                        metadata={"source": "cache", "cache_age_hours": cache_age.total_seconds() / 3600}
                    )

            # Extract fresh schema metadata
            db_connection = await self._get_database_connection(database_flavor, connection_params)
            schema_metadata = await self._extract_database_schema(db_connection, database_flavor, schema_names)

            # Cache the results
            self._schema_cache[cache_key] = (schema_metadata, datetime.utcnow())

            # Update knowledge base if requested
            if update_knowledge_base and self.knowledge_store:
                await self._update_schema_knowledge_base(schema_metadata)

            return ToolResult(
                status="success",
                result=schema_metadata.dict(),
                metadata={
                    "source": "database",
                    "schema_count": len(schema_names),
                    "table_count": len(schema_metadata.tables),
                    "view_count": len(schema_metadata.views),
                    "knowledge_base_updated": update_knowledge_base and self.knowledge_store is not None
                }
            )

        except Exception as e:
            return ToolResult(
                status="error",
                result=None,
                error=f"Schema extraction failed: {str(e)}",
                metadata={"operation": "schema_extract"}
            )

    # Additional helper methods would continue here...
    # Including _query_generation_operation, _query_validation_operation,
    # _query_execution_operation, and all the database-specific implementations

    def _generate_schema_cache_key(
        self,
        database_flavor: DatabaseFlavor,
        connection_params: Optional[Dict[str, Any]],
        schema_names: List[str]
    ) -> str:
        """Generate a unique cache key for schema metadata."""
        key_data = {
            "flavor": database_flavor.value,
            "params": connection_params or {},
            "schemas": sorted(schema_names)
        }
        return hashlib.md5(json.dumps(key_data, sort_keys=True).encode()).hexdigest()

    async def _get_database_connection(
        self,
        database_flavor: DatabaseFlavor,
        connection_params: Optional[Dict[str, Any]]
    ) -> AsyncDB:
        """Get or create a database connection using AsyncDB."""
        # Normalize connection parameters
        params = connection_params.copy() if connection_params else {}

        # Common mapping: username -> user (used by asyncpg and others)
        if 'username' in params and 'user' not in params:
            params['user'] = params.pop('username')

        driver_map = {
            DatabaseFlavor.POSTGRESQL: 'pg',
            DatabaseFlavor.MYSQL: 'mysql',
            DatabaseFlavor.SQLITE: 'sqlite',
        }
        driver = driver_map.get(database_flavor, database_flavor.value)
        return AsyncDB(driver, params=params)
```
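Aside: because the key is an MD5 digest of a `sort_keys=True` JSON dump with the schema list pre-sorted, the cache is insensitive to schema ordering. A quick check of that property, mirroring the method above:

```python
import hashlib
import json

def schema_cache_key(flavor: str, params: dict | None, schemas: list[str]) -> str:
    # Mirrors _generate_schema_cache_key above.
    key_data = {"flavor": flavor, "params": params or {}, "schemas": sorted(schemas)}
    return hashlib.md5(json.dumps(key_data, sort_keys=True).encode()).hexdigest()

a = schema_cache_key("postgresql", {"host": "db1"}, ["public", "sales"])
b = schema_cache_key("postgresql", {"host": "db1"}, ["sales", "public"])
assert a == b  # reordered schemas hit the same cache entry
```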
```python
    async def _extract_database_schema(
        self,
        db_connection: AsyncDB,
        database_flavor: DatabaseFlavor,
        schema_names: List[str]
    ) -> SchemaMetadata:
        """Extract comprehensive schema metadata from the database."""
        if database_flavor == DatabaseFlavor.POSTGRESQL:
            return await self._extract_postgresql_schema(db_connection, schema_names)

        raise NotImplementedError(f"Schema extraction not implemented for {database_flavor}")

    async def _extract_postgresql_schema(
        self,
        db: AsyncDB,
        schema_names: List[str]
    ) -> SchemaMetadata:
        """Extract schema for PostgreSQL."""
        tables_data = []
        async with await db.connection() as conn:
            schemas_list = ", ".join([f"'{s}'" for s in schema_names])
            if not schemas_list:
                schemas_list = "'public'"  # Default

            query = f"""
            SELECT t.table_schema, t.table_name, c.column_name, c.data_type
            FROM information_schema.tables t
            JOIN information_schema.columns c
              ON t.table_schema = c.table_schema AND t.table_name = c.table_name
            WHERE t.table_schema IN ({schemas_list})
            ORDER BY t.table_schema, t.table_name, c.ordinal_position
            """
            try:
                rows = await conn.fetch(query)  # Use fetch() if the driver exposes it
            except Exception:
                # Fall back to query() if fetch() is not available on the connection wrapper
                rows = await conn.query(query)

            # Check if rows is a list of lists (result set wrapper)
            if rows and isinstance(rows, list) and len(rows) > 0 and isinstance(rows[0], list):
                rows = rows[0]

            # Process rows
            grouped = {}
            for row in rows:
                # Handle possible dict or object access
                # (asyncpg.Record supports both .get() and ['key'])
                if hasattr(row, 'get'):
                    s_name = row.get('table_schema')
                    t_name = row.get('table_name')
                    c_name = row.get('column_name')
                    d_type = row.get('data_type')
                elif isinstance(row, (list, tuple)) and len(row) >= 4:
                    s_name = row[0]
                    t_name = row[1]
                    c_name = row[2]
                    d_type = row[3]
                else:
                    # Attempt dict access as fallback
                    try:
                        s_name = row['table_schema']
                        t_name = row['table_name']
                        c_name = row['column_name']
                        d_type = row['data_type']
                    except (TypeError, KeyError, IndexError):
                        continue  # Skip invalid rows

                k = (s_name, t_name)
                if k not in grouped:
                    grouped[k] = {
                        "schema": s_name,
                        "name": t_name,
                        "columns": []
                    }
                grouped[k]["columns"].append({"name": c_name, "type": d_type})

            tables_data = list(grouped.values())

        return SchemaMetadata(
            schema_name=",".join(schema_names),
            tables=tables_data,
            views=[],
            functions=[],
            procedures=[],
            indexes=[],
            constraints=[],
            last_updated=datetime.utcnow(),
            database_flavor=DatabaseFlavor.POSTGRESQL
        )
```
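Aside: the grouping loop pivots one row per column into one entry per table. A runnable miniature of the dict-access branch with two invented rows:

```python
sample_rows = [
    {"table_schema": "public", "table_name": "orders",
     "column_name": "order_id", "data_type": "integer"},
    {"table_schema": "public", "table_name": "orders",
     "column_name": "total", "data_type": "numeric"},
]

grouped: dict = {}
for row in sample_rows:  # same pivot as the loop above
    k = (row["table_schema"], row["table_name"])
    grouped.setdefault(k, {"schema": k[0], "name": k[1], "columns": []})
    grouped[k]["columns"].append({"name": row["column_name"], "type": row["data_type"]})

print(list(grouped.values()))
# [{'schema': 'public', 'name': 'orders',
#   'columns': [{'name': 'order_id', 'type': 'integer'},
#               {'name': 'total', 'type': 'numeric'}]}]
```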
```python
    async def _query_generation_operation(
        self,
        natural_language_query: str,
        database_flavor: DatabaseFlavor,
        connection_params: Optional[Dict[str, Any]],
        schema_names: List[str]
    ) -> ToolResult:
        """Generate SQL query from natural language using schema context."""
        try:
            # Get schema context for query generation
            schema_key = self._generate_schema_cache_key(database_flavor, connection_params, schema_names)
            cached_schema, _ = self._schema_cache.get(schema_key, (None, None))

            if not cached_schema:
                # If no cached schema, extract it first
                schema_result = await self._extract_schema_operation(
                    database_flavor, connection_params, schema_names, False, 24
                )
                if schema_result.status != "success" or not schema_result.result:
                    raise ValueError(f"Schema extraction failed: {schema_result.error or 'No result returned'}")

                cached_schema = SchemaMetadata(**schema_result.result)

            # Use database-specific query generator
            generator = self._query_generators.get(database_flavor)
            if not generator:
                raise ValueError(f"No query generator available for {database_flavor.value}")

            # Build rich context for LLM query generation
            schema_context = self._build_schema_context_for_llm(cached_schema, natural_language_query)

            # Generate the SQL query
            generated_sql = await generator(natural_language_query, schema_context)
            generated_sql = self._clean_sql(generated_sql)

            return ToolResult(
                status="success",
                result={
                    "natural_language_query": natural_language_query,
                    "sql_query": generated_sql,
                    "database_flavor": database_flavor.value,
                    "schema_context_used": len(schema_context.get("relevant_tables", [])),
                    "generation_timestamp": datetime.utcnow().isoformat()
                },
                metadata={
                    "operation": "query_generation",
                    "has_schema_context": bool(schema_context)
                }
            )

        except Exception as e:
            return ToolResult(
                status="error",
                result=None,
                error=f"Query generation failed: {str(e)}",
                metadata={"operation": "query_generation"}
            )
```
async def _query_validation_operation(
|
|
668
|
+
self,
|
|
669
|
+
sql_query: str,
|
|
670
|
+
database_flavor: DatabaseFlavor,
|
|
671
|
+
connection_params: Optional[Dict[str, Any]]
|
|
672
|
+
) -> ToolResult:
|
|
673
|
+
"""Validate SQL query for syntax, security, and performance."""
|
|
674
|
+
try:
|
|
675
|
+
validator = self._query_validators.get(database_flavor)
|
|
676
|
+
if not validator:
|
|
677
|
+
raise ValueError(f"No query validator available for {database_flavor.value}")
|
|
678
|
+
|
|
679
|
+
validation_result = await validator(sql_query)
|
|
680
|
+
|
|
681
|
+
return ToolResult(
|
|
682
|
+
status="success" if validation_result.is_valid else "warning",
|
|
683
|
+
result=validation_result.dict(),
|
|
684
|
+
metadata={
|
|
685
|
+
"operation": "query_validation",
|
|
686
|
+
"query_type": validation_result.query_type.value if validation_result.query_type else None
|
|
687
|
+
}
|
|
688
|
+
)
|
|
689
|
+
|
|
690
|
+
except Exception as e:
|
|
691
|
+
return ToolResult(
|
|
692
|
+
status="error",
|
|
693
|
+
result=None,
|
|
694
|
+
error=f"Query validation failed: {str(e)}",
|
|
695
|
+
metadata={"operation": "query_validation"}
|
|
696
|
+
)
|
|
697
|
+
|
|
698
|
+
async def _query_execution_operation(
|
|
699
|
+
self,
|
|
700
|
+
sql_query: str,
|
|
701
|
+
database_flavor: DatabaseFlavor,
|
|
702
|
+
connection_params: Optional[Dict[str, Any]],
|
|
703
|
+
max_rows: int,
|
|
704
|
+
timeout_seconds: int,
|
|
705
|
+
output_format: OutputFormat,
|
|
706
|
+
structured_output_schema: Optional[Dict[str, Any]]
|
|
707
|
+
) -> ToolResult:
|
|
708
|
+
"""Execute SQL query and format results according to specifications."""
|
|
709
|
+
try:
|
|
710
|
+
db_connection = await self._get_database_connection(database_flavor, connection_params)
|
|
711
|
+
|
|
712
|
+
# Execute query with timeout and row limit
|
|
713
|
+
start_time = datetime.utcnow()
|
|
714
|
+
|
|
715
|
+
# This integrates your existing DatabaseQueryTool logic
|
|
716
|
+
raw_results = await self._execute_query_with_asyncdb(
|
|
717
|
+
db_connection, sql_query, max_rows, timeout_seconds
|
|
718
|
+
)
|
|
719
|
+
|
|
720
|
+
execution_time = (datetime.utcnow() - start_time).total_seconds()
|
|
721
|
+
|
|
722
|
+
# Format results according to specified output format
|
|
723
|
+
formatted_results = await self._format_query_results(
|
|
724
|
+
raw_results, output_format, structured_output_schema
|
|
725
|
+
)
|
|
726
|
+
|
|
727
|
+
return ToolResult(
|
|
728
|
+
status="success",
|
|
729
|
+
result={
|
|
730
|
+
"data": formatted_results,
|
|
731
|
+
"row_count": len(raw_results) if isinstance(raw_results, list) else None,
|
|
732
|
+
"execution_time": execution_time,
|
|
733
|
+
"output_format": output_format.value,
|
|
734
|
+
"query": sql_query
|
|
735
|
+
},
|
|
736
|
+
metadata={
|
|
737
|
+
"operation": "query_execution",
|
|
738
|
+
"database_flavor": database_flavor.value,
|
|
739
|
+
"rows_returned": len(raw_results) if isinstance(raw_results, list) else 0
|
|
740
|
+
}
|
|
741
|
+
)
|
|
742
|
+
|
|
743
|
+
except Exception as e:
|
|
744
|
+
return ToolResult(
|
|
745
|
+
status="error",
|
|
746
|
+
result=None,
|
|
747
|
+
error=f"Query execution failed: {str(e)}",
|
|
748
|
+
metadata={"operation": "query_execution", "query": sql_query}
|
|
749
|
+
)
|
|
750
|
+
|
|
751
|
+
async def _explain_query_operation(
|
|
752
|
+
self,
|
|
753
|
+
sql_query: str,
|
|
754
|
+
database_flavor: DatabaseFlavor,
|
|
755
|
+
connection_params: Optional[Dict[str, Any]]
|
|
756
|
+
) -> ToolResult:
|
|
757
|
+
"""
|
|
758
|
+
Explain query execution plan and provide LLM-based optimizations.
|
|
759
|
+
"""
|
|
760
|
+
if not sql_query:
|
|
761
|
+
return ToolResult(
|
|
762
|
+
status="error",
|
|
763
|
+
result=None,
|
|
764
|
+
error="No SQL query provided for explanation",
|
|
765
|
+
metadata={"operation": "explain_query"}
|
|
766
|
+
)
|
|
767
|
+
|
|
768
|
+
try:
|
|
769
|
+
db_connection = await self._get_database_connection(database_flavor, connection_params)
|
|
770
|
+
|
|
771
|
+
# Determine appropriate EXPLAIN command
|
|
772
|
+
explain_cmd = f"EXPLAIN ANALYZE {sql_query}"
|
|
773
|
+
if database_flavor == DatabaseFlavor.MYSQL:
|
|
774
|
+
# MySQL 8.0.18+ supports EXPLAIN ANALYZE, otherwise fallback to EXPLAIN
|
|
775
|
+
# For safety/compatibility we might start with EXPLAIN if ANALYZE fails or just try
|
|
776
|
+
explain_cmd = f"EXPLAIN ANALYZE {sql_query}"
|
|
777
|
+
elif database_flavor == DatabaseFlavor.BIGQUERY:
|
|
778
|
+
# BigQuery doesn't support EXPLAIN ANALYZE syntax directly in this way usually
|
|
779
|
+
# It returns stats in job metadata.
|
|
780
|
+
# But we can try to use Dry Run or similar.
|
|
781
|
+
# For now, let's assume standard SQL syntax applies or let execution fail and fallback
|
|
782
|
+
pass
|
|
783
|
+
|
|
784
|
+
# Execute explanation
|
|
785
|
+
try:
|
|
786
|
+
raw_plan = await self._execute_query_with_asyncdb(
|
|
787
|
+
db_connection, explain_cmd, max_rows=0, timeout_seconds=30
|
|
788
|
+
)
|
|
789
|
+
except Exception as e:
|
|
790
|
+
# Fallback to simple EXPLAIN if ANALYZE fails (e.g. not supported or timeouts)
|
|
791
|
+
self.logger.warning(f"EXPLAIN ANALYZE failed, falling back to EXPLAIN: {e}")
|
|
792
|
+
explain_cmd = f"EXPLAIN {sql_query}"
|
|
793
|
+
raw_plan = await self._execute_query_with_asyncdb(
|
|
794
|
+
db_connection, explain_cmd, max_rows=0, timeout_seconds=30
|
|
795
|
+
)
|
|
796
|
+
|
|
797
|
+
# Format plan into string
|
|
798
|
+
plan_text = ""
|
|
799
|
+
if isinstance(raw_plan, list):
|
|
800
|
+
# Flatten the list of rows/dicts
|
|
801
|
+
for row in raw_plan:
|
|
802
|
+
if isinstance(row, dict):
|
|
803
|
+
# Usually the first column contains the plan output
|
|
804
|
+
plan_text += list(row.values())[0] + "\n"
|
|
805
|
+
elif isinstance(row, (list, tuple)):
|
|
806
|
+
plan_text += str(row[0]) + "\n"
|
|
807
|
+
else:
|
|
808
|
+
plan_text += str(row) + "\n"
|
|
809
|
+
else:
|
|
810
|
+
plan_text = str(raw_plan)
|
|
811
|
+
|
|
812
|
+
# Ask LLM to explain and optimize
|
|
813
|
+
llm_explanation = "No LLM configured for explanation."
|
|
814
|
+
if self.llm:
|
|
815
|
+
prompt = (
|
|
816
|
+
f"You are a database performance expert. Analyze the following query plan for a {database_flavor.value} database.\n"
|
|
817
|
+
f"Query:\n```sql\n{sql_query}\n```\n\n"
|
|
818
|
+
f"Execution Plan:\n```\n{plan_text}\n```\n\n"
|
|
819
|
+
"Please provide:\n"
|
|
820
|
+
"1. A human-readable explanation of how the query is executed.\n"
|
|
821
|
+
"2. Performance bottlenecks identified in the plan.\n"
|
|
822
|
+
"3. Concrete suggestions for indexes or query rewrites to improve performance.\n"
|
|
823
|
+
"4. Rating of current query efficiency (1-10)."
|
|
824
|
+
)
|
|
825
|
+
|
|
826
|
+
response = await self.llm.ask(prompt)
|
|
827
|
+
if isinstance(response, AIMessage):
|
|
828
|
+
llm_explanation = str(response.output).strip()
|
|
829
|
+
elif isinstance(response, dict) and 'content' in response:
|
|
830
|
+
llm_explanation = str(response['content']).strip()
|
|
831
|
+
else:
|
|
832
|
+
llm_explanation = str(response).strip()
|
|
833
|
+
|
|
834
|
+
return ToolResult(
|
|
835
|
+
status="success",
|
|
836
|
+
result={
|
|
837
|
+
"query": sql_query,
|
|
838
|
+
"plan": plan_text,
|
|
839
|
+
"analysis": llm_explanation,
|
|
840
|
+
"database_flavor": database_flavor.value
|
|
841
|
+
},
|
|
842
|
+
metadata={
|
|
843
|
+
"operation": "explain_query",
|
|
844
|
+
"command_used": explain_cmd
|
|
845
|
+
}
|
|
846
|
+
)
|
|
847
|
+
|
|
848
|
+
except Exception as e:
|
|
849
|
+
return ToolResult(
|
|
850
|
+
status="error",
|
|
851
|
+
result=None,
|
|
852
|
+
error=f"Query explanation failed: {str(e)}",
|
|
853
|
+
metadata={"operation": "explain_query"}
|
|
854
|
+
)
|
|
855
|
+
|
|
856
|
+
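    # Usage sketch (illustrative only; assumes a configured tool instance
    # named `tool` with an attached LLM -- not part of the shipped module):
    #
    #     result = await tool._explain_query_operation(
    #         sql_query="SELECT * FROM orders WHERE total > 100",
    #         database_flavor=DatabaseFlavor.POSTGRESQL,
    #         connection_params=None,
    #     )
    #     if result.status == "success":
    #         print(result.result["analysis"])  # LLM commentary on the plan
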
    def _build_schema_context_for_llm(
        self,
        schema_metadata: SchemaMetadata,
        natural_language_query: str
    ) -> Dict[str, Any]:
        """
        Build rich schema context for LLM query generation.

        This is a critical method that determines query generation quality.
        It intelligently selects relevant schema elements based on the natural language query.
        """
        # Use vector similarity or keyword matching to find relevant tables
        relevant_tables = self._find_relevant_tables(schema_metadata, natural_language_query)

        # Build comprehensive context including relationships, constraints, and sample data
        context = {
            "database_flavor": schema_metadata.database_flavor.value,
            "schema_name": schema_metadata.schema_name,
            "relevant_tables": relevant_tables,
            "table_relationships": self._extract_table_relationships(schema_metadata, relevant_tables),
            "common_patterns": self._get_query_patterns_for_tables(relevant_tables),
            "data_types_guide": self._get_data_type_guide(schema_metadata.database_flavor)
        }

        return context

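    # Shape of the returned context (hypothetical example for a schema with
    # `orders` and `customers` tables; field values depend on the metadata):
    #
    #     {
    #         "database_flavor": "postgresql",
    #         "schema_name": "public",
    #         "relevant_tables": [{"table_name": "orders", "relevance_score": 15, ...}],
    #         "table_relationships": [{"type": "foreign_key",
    #                                  "source_table": "orders",
    #                                  "source_column": "customer_id",
    #                                  "target_table": "customers",
    #                                  "target_column": "id"}],
    #         "common_patterns": [...],
    #         "data_types_guide": {...},
    #     }
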
    async def _execute_query_with_asyncdb(
        self,
        db_connection: AsyncDB,
        sql_query: str,
        max_rows: int,
        timeout_seconds: int
    ) -> Any:
        """Execute query using AsyncDB with proper error handling and limits."""
        # This integrates your existing DatabaseQueryTool execution logic
        # but with enhanced error handling and result limiting

        try:
            # Add LIMIT clause if not present and max_rows is specified
            if max_rows > 0 and "LIMIT" not in sql_query.upper():
                sql_query = f"{sql_query.rstrip(';')} LIMIT {max_rows};"

            # Execute with timeout using asyncio
            async with await db_connection.connection() as conn:
                return await asyncio.wait_for(
                    conn.fetchall(sql_query),
                    timeout=timeout_seconds
                )

        except asyncio.TimeoutError as e:
            raise TimeoutError(
                f"Query execution timed out after {timeout_seconds} seconds"
            ) from e
        except Exception as e:
            raise RuntimeError(
                f"Database execution error: {str(e)}"
            ) from e

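    # Example of the LIMIT rewrite performed above (illustrative):
    #
    #     "SELECT id FROM users;"  with max_rows=100
    #       -> "SELECT id FROM users LIMIT 100;"
    #
    # The substring check is deliberately conservative: a query that already
    # mentions LIMIT anywhere is left untouched.
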
    async def _format_query_results(
        self,
        raw_results: Any,
        output_format: OutputFormat,
        structured_output_schema: Optional[Dict[str, Any]]
    ) -> Any:
        """Format query results according to the specified output format."""
        if output_format == OutputFormat.PANDAS:
            return pd.DataFrame(raw_results) if raw_results else pd.DataFrame()
        elif output_format == OutputFormat.JSON:
            return json.dumps(raw_results, default=str, indent=2)
        elif output_format == OutputFormat.DICT:
            return raw_results
        elif output_format == OutputFormat.CSV:
            df = pd.DataFrame(raw_results) if raw_results else pd.DataFrame()
            return df.to_csv(index=False)
        elif output_format == OutputFormat.STRUCTURED and structured_output_schema:
            # Convert results to structured records based on the provided schema
            return self._convert_to_structured_output(raw_results, structured_output_schema)
        else:
            return raw_results

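    # Illustrative renderings of the same two rows per output format:
    #
    #     rows = [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]
    #     OutputFormat.PANDAS -> DataFrame with columns ["id", "name"]
    #     OutputFormat.JSON   -> '[\n  {\n    "id": 1,\n    "name": "a"\n  }, ...'
    #     OutputFormat.CSV    -> "id,name\n1,a\n2,b\n"
    #     OutputFormat.DICT   -> rows, unchanged
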
    # Database-specific implementations (these would replace your current separate tools)
    async def _generate_postgresql_query(self, natural_language: str, schema_context: Dict) -> str:
        """
        Generate PostgreSQL-specific SQL from natural language.

        This method would integrate your existing SQLAgent logic but with enhanced
        schema context and PostgreSQL-specific optimizations.
        """
        # Build prompt with rich schema context
        prompt = self._build_query_generation_prompt(
            natural_language, schema_context, "postgresql"
        )

        # Use your existing LLM client to generate the query
        # This would integrate with your AI-Parrot LLM clients
        return await self._call_llm_for_query_generation(prompt)

    async def _validate_postgresql_query(self, query: str) -> QueryValidationResult:
        """
        Validate PostgreSQL query for syntax, security, and performance.

        This provides the validation layer that was missing from your current SQLAgent.
        """
        validation_result = QueryValidationResult(
            is_valid=True,
            query_type=None,
            affected_tables=[],
            estimated_cost=None,
            warnings=[],
            errors=[],
            security_checks={}
        )

        try:
            # Parse query to determine type and affected tables
            query_upper = query.strip().upper()
            if query_upper.startswith('SELECT'):
                validation_result.query_type = QueryType.SELECT
            elif query_upper.startswith('INSERT'):
                validation_result.query_type = QueryType.INSERT
            # ... other query types

            # Security checks
            validation_result.security_checks = {
                "no_sql_injection_patterns": self._check_sql_injection_patterns(query),
                "no_dangerous_operations": self._check_dangerous_operations(query),
                "proper_quoting": self._check_proper_quoting(query)
            }

            # Syntax validation (could use sqlparse or connect to the database for EXPLAIN)
            syntax_valid = await self._validate_syntax_postgresql(query)
            if not syntax_valid:
                validation_result.is_valid = False
                validation_result.errors.append("Invalid SQL syntax")

            # Performance warnings
            if "SELECT *" in query_upper:
                validation_result.warnings.append("Consider specifying explicit columns instead of SELECT *")

            return validation_result

        except Exception as e:
            validation_result.is_valid = False
            validation_result.errors.append(f"Validation error: {str(e)}")
            return validation_result

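    # Illustrative result for `SELECT * FROM users` (assumes the query parses):
    #
    #     QueryValidationResult(
    #         is_valid=True,
    #         query_type=QueryType.SELECT,
    #         warnings=["Consider specifying explicit columns instead of SELECT *"],
    #         errors=[],
    #         security_checks={"no_sql_injection_patterns": True,
    #                          "no_dangerous_operations": True,
    #                          "proper_quoting": True},
    #         ...
    #     )
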
    def _check_dangerous_operations(self, query: str) -> bool:
        """
        Check if query contains dangerous operations that should be blocked.

        Returns:
            True if query is SAFE (no dangerous operations)
            False if dangerous operations detected
        """
        query_upper = query.upper()

        dangerous_patterns = [
            # DDL operations
            r'\bDROP\s+(TABLE|DATABASE|SCHEMA|INDEX|VIEW|PROCEDURE|FUNCTION)\b',
            r'\bTRUNCATE\s+TABLE\b',
            r'\bALTER\s+(TABLE|DATABASE|SCHEMA)\s+.*\s+DROP\b',
            # DML without WHERE
            r'\bDELETE\s+FROM\s+\w+\s*;?\s*$',
            # Admin commands
            r'\bGRANT\b',
            r'\bREVOKE\b',
            r'\bCREATE\s+USER\b',
            r'\bDROP\s+USER\b',
            r'\bALTER\s+USER\b',
            # Command execution (PostgreSQL)
            r'\bCOPY\s+.*\s+TO\s+PROGRAM\b',
            # SQL Server
            r'\bEXEC\s*\(',
            r'\bXP_CMDSHELL\b',
            # MySQL file operations
            r'\bLOAD_FILE\b',
            r'\bINTO\s+OUTFILE\b',
            r'\bINTO\s+DUMPFILE\b',
        ]

        for pattern in dangerous_patterns:
            if re.search(pattern, query_upper, re.IGNORECASE | re.DOTALL):
                return False

        # Check DELETE/UPDATE without WHERE
        if re.search(r'\bDELETE\s+FROM\s+\w+\s*$', query_upper):
            return False

        update_match = re.search(r'\bUPDATE\s+\w+\s+SET\s+', query_upper)
        if update_match and 'WHERE' not in query_upper:
            return False

        return True

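    # Illustrative outcomes of this check:
    #
    #     _check_dangerous_operations("SELECT id FROM users WHERE id = 1")  -> True  (safe)
    #     _check_dangerous_operations("DROP TABLE users")                   -> False
    #     _check_dangerous_operations("DELETE FROM users")                  -> False (no WHERE)
    #     _check_dangerous_operations("DELETE FROM users WHERE id = 1")     -> True  (safe)
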
    def _check_sql_injection_patterns(self, query: str) -> bool:
        """
        Check for common SQL injection patterns.

        Returns:
            True if no injection patterns found (SAFE)
            False if potential injection detected
        """
        injection_patterns = [
            # Union/Boolean-based
            r"'\s*(OR|AND)\s+['\"0-9]",
            r"'\s*OR\s+1\s*=\s*1",
            r"'\s*OR\s+'[^']*'\s*=\s*'[^']*'",
            # Comment-based
            r";\s*--",
            r";\s*/\*",
            r"--\s*$",
            # Stacked queries
            r"'\s*;\s*(DROP|DELETE|UPDATE|INSERT|EXEC)\b",
            r";\s*(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER)\b",
            # UNION injection
            r"\bUNION\s+(ALL\s+)?SELECT\b.*\bFROM\b",
            # Time-based
            r"\bSLEEP\s*\(",
            r"\bWAITFOR\s+DELAY\b",
            r"\bBENCHMARK\s*\(",
            r"\bPG_SLEEP\s*\(",
            # Encoding attempts
            r"0x[0-9a-fA-F]+",
            r"\bCHAR\s*\(\s*\d+\s*\)",
        ]

        for pattern in injection_patterns:
            if re.search(pattern, query, re.IGNORECASE):
                return False

        return True

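    # Illustrative outcomes (note the hex-literal pattern is intentionally
    # strict and also flags legitimate 0x... constants):
    #
    #     _check_sql_injection_patterns("SELECT 1")                          -> True  (safe)
    #     _check_sql_injection_patterns("SELECT * FROM t WHERE a='' OR 1=1") -> False
    #     _check_sql_injection_patterns("SELECT PG_SLEEP(10)")               -> False
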
    def _check_proper_quoting(self, query: str) -> bool:
        """
        Check if string literals are properly quoted.

        Returns:
            True if quoting appears proper (SAFE)
            False if improper quoting detected
        """
        # Check for unbalanced quotes. A doubled quote ('') is an escape that
        # contains two quote characters pairing with each other, so both are
        # removed from the balance count.
        single_quotes = query.count("'") - query.count("\\'") - 2 * query.count("''")
        double_quotes = query.count('"') - query.count('\\"') - 2 * query.count('""')

        if single_quotes % 2 != 0 or double_quotes % 2 != 0:
            return False

        # Check for dangerous patterns after string literals
        dangerous_patterns = [
            r"'\s*\)\s*(OR|AND|UNION)\b",
            r"'\s*;\s*\w+",
        ]

        for pattern in dangerous_patterns:
            if re.search(pattern, query, re.IGNORECASE):
                return False

        return True

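    # Illustrative outcomes of the quoting check:
    #
    #     _check_proper_quoting("SELECT 'abc' FROM t")  -> True  (balanced)
    #     _check_proper_quoting("SELECT 'abc FROM t")   -> False (unbalanced quote)
    #     _check_proper_quoting("name = 'O''Brien'")    -> True  (doubled-quote escape)
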
    async def _validate_syntax_postgresql(self, query: str) -> bool:
        """Validate PostgreSQL query syntax using pattern matching."""
        try:
            query_stripped = query.strip().rstrip(';')
            query_upper = query_stripped.upper()

            valid_starts = [
                'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'WITH',
                'CREATE', 'ALTER', 'DROP', 'TRUNCATE',
                'BEGIN', 'COMMIT', 'ROLLBACK', 'SAVEPOINT',
                'EXPLAIN', 'ANALYZE', 'VACUUM', 'REINDEX',
                'GRANT', 'REVOKE', 'SET', 'SHOW', 'RESET',
                'COPY', 'CALL', 'DO', 'LOCK',
            ]

            first_word = query_upper.split()[0] if query_stripped else ''
            if first_word not in valid_starts:
                return False

            # Validate statement structures
            if query_upper.startswith('INSERT') and 'INTO' not in query_upper:
                return False
            if query_upper.startswith('UPDATE') and 'SET' not in query_upper:
                return False
            if query_upper.startswith('DELETE') and 'FROM' not in query_upper:
                return False

            # Check bracket balance
            if query.count('(') != query.count(')'):
                return False
            if query.count('[') != query.count(']'):
                return False

            return True
        except Exception:
            return False

    # =========================================================================
    # MYSQL METHODS
    # =========================================================================

    async def _generate_mysql_query(
        self,
        natural_language: str,
        schema_context: Dict[str, Any]
    ) -> str:
        """Generate MySQL-specific SQL from natural language."""
        prompt = self._build_query_generation_prompt(
            natural_language, schema_context, "mysql"
        )

        mysql_instructions = """
        MySQL-Specific Rules:
        1. Use backticks (`) for identifier quoting
        2. Use LIMIT for row limiting
        3. Use IFNULL() instead of COALESCE() for two arguments
        4. Use NOW() for current timestamp
        5. Use DATE_FORMAT() for date formatting
        6. Use CONCAT() for string concatenation
        7. Boolean values are 1/0
        8. Use REGEXP for regex matching
        """

        generated_query = await self._call_llm_for_query_generation(
            f"{prompt}\n\n{mysql_instructions}"
        )

        return self._ensure_mysql_compatibility(generated_query)

    def _ensure_mysql_compatibility(self, query: str) -> str:
        """Post-process query to ensure MySQL compatibility."""
        result = query

        # Replace double quotes with backticks for identifiers
        result = re.sub(r'"(\w+)"(?=\s*[,.\)\s]|$)', r'`\1`', result)

        # Rewrite two-argument COALESCE() calls as IFNULL()
        result = re.sub(
            r'\bCOALESCE\s*\(\s*([^,]+)\s*,\s*([^,\)]+)\s*\)',
            r'IFNULL(\1, \2)',
            result,
            flags=re.IGNORECASE
        )

        # Replace TRUE/FALSE literals with 1/0
        result = re.sub(r'\bTRUE\b', '1', result, flags=re.IGNORECASE)
        result = re.sub(r'\bFALSE\b', '0', result, flags=re.IGNORECASE)

        return result

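    # Illustrative rewrite chain performed above:
    #
    #     SELECT COALESCE("name", 'n/a') FROM t WHERE active = TRUE
    #       -> SELECT IFNULL(`name`, 'n/a') FROM t WHERE active = 1
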
    async def _validate_mysql_query(self, query: str) -> QueryValidationResult:
        """Validate MySQL query for syntax, security, and performance."""
        validation_result = QueryValidationResult(
            is_valid=True,
            query_type=None,
            affected_tables=[],
            estimated_cost=None,
            warnings=[],
            errors=[],
            security_checks={}
        )

        try:
            query_upper = query.strip().upper()

            # Determine query type
            for qt in QueryType:
                if query_upper.startswith(qt.value):
                    validation_result.query_type = qt
                    break

            # Extract tables
            validation_result.affected_tables = self._extract_tables_from_query(query)

            # Security checks
            validation_result.security_checks = {
                "no_sql_injection_patterns": self._check_sql_injection_patterns(query),
                "no_dangerous_operations": self._check_dangerous_operations(query),
                "proper_quoting": self._check_proper_quoting(query)
            }

            if not all(validation_result.security_checks.values()):
                validation_result.is_valid = False
                for check, passed in validation_result.security_checks.items():
                    if not passed:
                        validation_result.errors.append(f"Security check failed: {check}")

            # Syntax validation
            if not await self._validate_syntax_mysql(query):
                validation_result.is_valid = False
                validation_result.errors.append("Invalid MySQL syntax")

            # Performance warnings
            if "SELECT *" in query_upper:
                validation_result.warnings.append("Consider specifying explicit columns")

            if re.search(r"LIKE\s*'%", query, re.IGNORECASE):
                validation_result.warnings.append("Leading wildcard may prevent index usage")

            return validation_result

        except Exception as e:
            validation_result.is_valid = False
            validation_result.errors.append(f"Validation error: {str(e)}")
            return validation_result

    async def _validate_syntax_mysql(self, query: str) -> bool:
        """Validate MySQL-specific query syntax."""
        try:
            query_stripped = query.strip().rstrip(';')
            query_upper = query_stripped.upper()

            valid_starts = [
                'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'REPLACE',
                'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'RENAME',
                'START', 'BEGIN', 'COMMIT', 'ROLLBACK', 'SAVEPOINT',
                'SET', 'SHOW', 'DESCRIBE', 'DESC', 'EXPLAIN',
                'GRANT', 'REVOKE', 'LOCK', 'UNLOCK', 'USE', 'WITH',
            ]

            first_word = query_upper.split()[0] if query_stripped else ''
            if first_word not in valid_starts:
                return False

            # Validate statement structures
            if first_word in ('INSERT', 'REPLACE') and 'INTO' not in query_upper:
                return False
            if first_word == 'UPDATE' and 'SET' not in query_upper:
                return False

            # Check bracket and backtick balance
            if query.count('(') != query.count(')'):
                return False
            if query.count('`') % 2 != 0:
                return False

            return True
        except Exception:
            return False

    # =========================================================================
    # BIGQUERY METHODS
    # =========================================================================

    async def _generate_bigquery_query(
        self,
        natural_language: str,
        schema_context: Dict[str, Any]
    ) -> str:
        """Generate BigQuery-specific SQL from natural language."""
        prompt = self._build_query_generation_prompt(
            natural_language, schema_context, "bigquery"
        )

        bigquery_instructions = """
        BigQuery-Specific Rules:
        1. Use backticks for table names: `project.dataset.table`
        2. Use STRUCT<> and ARRAY<> for complex types
        3. Use UNNEST() to flatten arrays
        4. Use SAFE_DIVIDE() for division with potential zeros
        5. Use FORMAT_DATE/FORMAT_TIMESTAMP for date formatting
        6. Use DATE_DIFF, TIMESTAMP_DIFF for date differences
        7. Use QUALIFY clause for window function filtering
        8. Use Standard SQL (prefix with #standardSQL if needed)
        """

        generated_query = await self._call_llm_for_query_generation(
            f"{prompt}\n\n{bigquery_instructions}"
        )

        return self._ensure_bigquery_compatibility(generated_query)

    def _ensure_bigquery_compatibility(self, query: str) -> str:
        """Post-process query to ensure BigQuery compatibility."""
        result = query

        # Backtick-quote fully-qualified project.dataset.table names
        result = re.sub(
            r'(?<![`\w])(\w+)\.(\w+)\.(\w+)(?![`\w])',
            r'`\1.\2.\3`',
            result
        )

        # Replace NOW() with CURRENT_TIMESTAMP()
        result = re.sub(r'\bNOW\s*\(\s*\)', 'CURRENT_TIMESTAMP()', result, flags=re.IGNORECASE)

        # Replace GETDATE() as well
        result = re.sub(r'\bGETDATE\s*\(\s*\)', 'CURRENT_TIMESTAMP()', result, flags=re.IGNORECASE)

        return result

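    # Illustrative rewrites performed above:
    #
    #     SELECT NOW() FROM myproject.sales.orders
    #       -> SELECT CURRENT_TIMESTAMP() FROM `myproject.sales.orders`
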
    async def _validate_bigquery_query(self, query: str) -> QueryValidationResult:
        """Validate BigQuery query for syntax, security, and performance."""
        validation_result = QueryValidationResult(
            is_valid=True,
            query_type=None,
            affected_tables=[],
            estimated_cost=None,
            warnings=[],
            errors=[],
            security_checks={}
        )

        try:
            query_upper = query.strip().upper()

            # Remove SQL dialect prefix
            if query_upper.startswith('#'):
                newline_idx = query.find('\n')
                if newline_idx > 0:
                    query_upper = query[newline_idx:].strip().upper()

            # Determine query type
            if query_upper.startswith(('SELECT', 'WITH')):
                validation_result.query_type = QueryType.SELECT
            elif query_upper.startswith('MERGE'):
                validation_result.query_type = QueryType.UPDATE
            else:
                for qt in QueryType:
                    if query_upper.startswith(qt.value):
                        validation_result.query_type = qt
                        break

            # Extract tables
            validation_result.affected_tables = self._extract_bigquery_tables(query)

            # Security checks
            validation_result.security_checks = {
                "no_sql_injection_patterns": self._check_sql_injection_patterns(query),
                "no_dangerous_operations": self._check_dangerous_bigquery_operations(query),
                "proper_quoting": self._check_proper_quoting(query)
            }

            if not all(validation_result.security_checks.values()):
                validation_result.is_valid = False
                for check, passed in validation_result.security_checks.items():
                    if not passed:
                        validation_result.errors.append(f"Security check failed: {check}")

            # Syntax validation
            if not await self._validate_syntax_bigquery(query):
                validation_result.is_valid = False
                validation_result.errors.append("Invalid BigQuery SQL syntax")

            # Performance warnings
            if "SELECT *" in query_upper:
                validation_result.warnings.append(
                    "SELECT * scans all columns - specify needed columns for cost reduction"
                )

            if 'WHERE' not in query_upper and '_PARTITIONTIME' not in query_upper:
                validation_result.warnings.append(
                    "Consider adding partition filter for cost reduction"
                )

            if query.strip().startswith('#legacySQL'):
                validation_result.warnings.append("Consider migrating to Standard SQL")

            return validation_result

        except Exception as e:
            validation_result.is_valid = False
            validation_result.errors.append(f"Validation error: {str(e)}")
            return validation_result

    def _check_dangerous_bigquery_operations(self, query: str) -> bool:
        """Check for dangerous BigQuery operations. Returns True if SAFE."""
        query_upper = query.upper()

        dangerous_patterns = [
            r'\bDROP\s+(TABLE|SCHEMA|VIEW|FUNCTION)\b',
            r'\bTRUNCATE\s+TABLE\b',
            r'\bDELETE\s+FROM\s+`[^`]+`\s*$',
            r'\bDROP\s+ALL\s+ROW\s+ACCESS\s+POLICIES\b',
        ]

        for pattern in dangerous_patterns:
            if re.search(pattern, query_upper, re.IGNORECASE | re.DOTALL):
                return False

        return True

    def _extract_bigquery_tables(self, query: str) -> List[str]:
        """Extract table names from BigQuery query."""
        tables = set()

        # Backtick-quoted fully-qualified names
        tables.update(re.findall(r'`([a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+)`', query))
        tables.update(re.findall(r'`([a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+)`', query))

        # Standard table references
        tables.update(self._extract_tables_from_query(query))

        return list(tables)

    async def _validate_syntax_bigquery(self, query: str) -> bool:
        """Validate BigQuery-specific query syntax."""
        try:
            query_stripped = query.strip()

            if query_stripped.startswith('#'):
                newline_idx = query_stripped.find('\n')
                if newline_idx > 0:
                    query_stripped = query_stripped[newline_idx:].strip()

            query_stripped = query_stripped.rstrip(';')
            query_upper = query_stripped.upper()

            valid_starts = [
                'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE',
                'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'WITH',
                'DECLARE', 'SET', 'EXECUTE', 'BEGIN', 'IF',
                'EXPORT', 'LOAD', 'GRANT', 'REVOKE', 'ASSERT',
            ]

            first_word = query_upper.split()[0] if query_stripped else ''
            if first_word not in valid_starts:
                return False

            # Check balance
            if query.count('(') != query.count(')'):
                return False
            if query.count('`') % 2 != 0:
                return False
            if query.count('[') != query.count(']'):
                return False
            # Angle brackets only need to balance when used as type parameters
            # (STRUCT<...>, ARRAY<...>); a raw count would wrongly reject
            # comparison operators such as `a > 5`.
            if 'STRUCT<' in query_upper or 'ARRAY<' in query_upper:
                if query.count('<') != query.count('>'):
                    return False

            return True
        except Exception:
            return False

    # =========================================================================
    # HELPER METHODS
    # =========================================================================

    def _extract_tables_from_query(self, query: str) -> List[str]:
        """Extract table names from SQL query."""
        tables = set()

        patterns = [
            r'\bFROM\s+([`"\[]?[\w.-]+[`"\]]?)',
            r'\bJOIN\s+([`"\[]?[\w.-]+[`"\]]?)',
            r'\bINSERT\s+INTO\s+([`"\[]?[\w.-]+[`"\]]?)',
            r'\bUPDATE\s+([`"\[]?[\w.-]+[`"\]]?)',
            r'\bDELETE\s+FROM\s+([`"\[]?[\w.-]+[`"\]]?)',
        ]

        for pattern in patterns:
            for match in re.findall(pattern, query, re.IGNORECASE):
                tables.add(match.strip('`"[]'))

        return list(tables)

    def _find_relevant_tables(
        self,
        schema_metadata: SchemaMetadata,
        natural_language_query: str
    ) -> List[Dict[str, Any]]:
        """Find tables relevant to the natural language query."""
        relevant_tables = []
        query_lower = natural_language_query.lower()

        keywords = set(re.findall(r'\b\w+\b', query_lower))

        stop_words = {
            'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been',
            'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would',
            'could', 'should', 'may', 'might', 'must', 'shall', 'can',
            'to', 'of', 'in', 'for', 'on', 'with', 'at', 'by', 'from',
            'and', 'or', 'but', 'if', 'as', 'show', 'get', 'find', 'list',
            'give', 'tell', 'me', 'i', 'you', 'we', 'they', 'select', 'all',
        }

        keywords = keywords - stop_words

        for table in schema_metadata.tables:
            table_name = table.get('table_name', '').lower()
            columns = table.get('columns', [])

            score = 0
            matched_columns = []

            # Check table name
            table_words = set(re.findall(r'\w+', table_name))
            if table_words & keywords:
                score += 10

            for keyword in keywords:
                if keyword in table_name:
                    score += 5

            # Check columns
            for column in columns:
                col_name = column.get('column_name', '').lower()
                col_words = set(re.findall(r'\w+', col_name))

                if col_words & keywords:
                    score += 3
                    matched_columns.append(col_name)

                for keyword in keywords:
                    if keyword in col_name and col_name not in matched_columns:
                        score += 1
                        matched_columns.append(col_name)

            if score > 0:
                relevant_tables.append({
                    'table_name': table.get('table_name'),
                    'schema': table.get('schema', schema_metadata.schema_name),
                    'columns': columns,
                    'matched_columns': matched_columns,
                    'relevance_score': score,
                    'comment': table.get('comment', '')
                })

        relevant_tables.sort(key=lambda x: x['relevance_score'], reverse=True)
        return relevant_tables[:10]

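    # Illustrative scoring for the request "show total sales by customer"
    # against a table `sales` with columns (id, customer_id, total):
    #
    #     table name "sales" word-matches "sales"      -> +10, plus +5 substring
    #     column "customer_id" word-matches "customer" -> +3
    #     column "total" word-matches "total"          -> +3
    #     relevance_score = 21
    #
    # Tables are ranked by this score and the top 10 are returned.
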
    def _extract_table_relationships(
        self,
        schema_metadata: SchemaMetadata,
        relevant_tables: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Extract relationships between relevant tables."""
        relationships = []
        relevant_table_names = {t['table_name'] for t in relevant_tables}

        # From constraints
        for constraint in schema_metadata.constraints:
            if constraint.get('constraint_type') == 'FOREIGN KEY':
                source_table = constraint.get('table_name')
                target_table = constraint.get('referenced_table')

                if source_table in relevant_table_names or target_table in relevant_table_names:
                    relationships.append({
                        'type': 'foreign_key',
                        'source_table': source_table,
                        'source_column': constraint.get('column_name'),
                        'target_table': target_table,
                        'target_column': constraint.get('referenced_column'),
                    })

        # Infer from naming conventions
        for table in relevant_tables:
            for column in table.get('columns', []):
                col_name = column.get('column_name', '')

                if col_name.endswith('_id'):
                    potential_table = col_name[:-3]
                    for pt in [potential_table, potential_table + 's', potential_table + 'es']:
                        if pt in relevant_table_names:
                            relationships.append({
                                'type': 'inferred',
                                'source_table': table['table_name'],
                                'source_column': col_name,
                                'target_table': pt,
                                'target_column': 'id',
                            })
                            break

        return relationships

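    # Illustrative inference: a `customer_id` column on `orders` yields
    #
    #     {"type": "inferred", "source_table": "orders",
    #      "source_column": "customer_id", "target_table": "customers",
    #      "target_column": "id"}
    #
    # provided `customer`, `customers`, or `customeres` appears among the
    # relevant tables (singular, +s, and +es naming conventions are tried).
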
    def _get_query_patterns_for_tables(
        self,
        relevant_tables: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Generate common query patterns for relevant tables."""
        patterns = []

        for table in relevant_tables[:3]:
            table_name = table['table_name']
            columns = table.get('columns', [])

            if columns:
                col_list = ', '.join([c['column_name'] for c in columns[:5]])
                patterns.append({
                    'description': f'Select from {table_name}',
                    'pattern': f'SELECT {col_list} FROM {table_name} WHERE ...',
                })

            if numeric_cols := [
                c for c in columns if c.get('data_type', '').lower() in (
                    'integer', 'int', 'bigint', 'numeric', 'decimal', 'float'
                )
            ]:
                patterns.append({
                    'description': f'Aggregate {table_name}',
                    'pattern': f'SELECT COUNT(*), SUM({numeric_cols[0]["column_name"]}) FROM {table_name} GROUP BY ...',
                })

        return patterns

    def _get_data_type_guide(self, database_flavor: DatabaseFlavor) -> Dict[str, Any]:
        """Get data type information for database flavor."""
        guides = {
            DatabaseFlavor.POSTGRESQL: {
                'string_concat': '|| operator or CONCAT()',
                'null_handling': 'IS NULL / IS NOT NULL, COALESCE()',
                'boolean_type': 'BOOLEAN',
            },
            DatabaseFlavor.MYSQL: {
                'string_concat': 'CONCAT() function',
                'null_handling': 'IS NULL / IS NOT NULL, IFNULL()',
                'boolean_type': 'TINYINT(1)',
            },
            DatabaseFlavor.BIGQUERY: {
                'string_concat': 'CONCAT() or ||',
                'null_handling': 'IS NULL / IS NOT NULL, IFNULL(), COALESCE()',
                'boolean_type': 'BOOL',
            }
        }
        return guides.get(database_flavor, guides[DatabaseFlavor.POSTGRESQL])

    def _build_query_generation_prompt(
        self,
        natural_language: str,
        schema_context: Dict[str, Any],
        dialect: str
    ) -> str:
        """Build prompt for LLM query generation."""
        prompt_parts = [
            f"Generate a {dialect.upper()} SQL query for:",
            f"\nRequest: {natural_language}",
            f"\nDatabase: {schema_context.get('database_flavor', dialect).upper()}",
            "\n\nAvailable Tables:",
        ]

        for table in schema_context.get('relevant_tables', [])[:5]:
            prompt_parts.append(f"\n\nTable: {table.get('table_name')}")
            columns = table.get('columns', [])[:15]
            if columns:
                prompt_parts.append("\nColumns:")
                for col in columns:
                    col_info = f"  - {col.get('column_name')}: {col.get('data_type', 'unknown')}"
                    prompt_parts.append(col_info)

        relationships = schema_context.get('table_relationships', [])
        if relationships:
            prompt_parts.append("\n\nRelationships:")
            for rel in relationships[:5]:
                prompt_parts.append(
                    f"  - {rel['source_table']}.{rel['source_column']} -> "
                    f"{rel['target_table']}.{rel['target_column']}"
                )

        prompt_parts.append("\n\nGenerate only the SQL query, no explanations:")
        return '\n'.join(prompt_parts)

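    # Illustrative rendered prompt (hypothetical schema, blank lines abridged):
    #
    #     Generate a POSTGRESQL SQL query for:
    #     Request: show total sales by customer
    #     Database: POSTGRESQL
    #     Available Tables:
    #     Table: orders
    #     Columns:
    #       - id: integer
    #       - customer_id: integer
    #       - total: numeric
    #     Relationships:
    #       - orders.customer_id -> customers.id
    #     Generate only the SQL query, no explanations:
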
    async def _call_llm_for_query_generation(self, prompt: str) -> str:
        """Call LLM client to generate SQL query."""
        system_msg = "You are a SQL expert. Generate precise SQL queries. Return only the SQL, no explanations."

        if self.llm:
            response = await self.llm.ask(prompt, system_prompt=system_msg)
            if isinstance(response, AIMessage):
                return str(response.output).strip()
            # Fallback for clients that return a plain dict instead of an AIMessage
            elif isinstance(response, dict) and 'content' in response:
                return str(response['content']).strip()
            return str(response).strip()

        if hasattr(self, 'agent') and self.agent:
            response = await self.agent.llm.acomplete(prompt, system_message=system_msg)
            return response.strip()

        if hasattr(self, 'llm_client') and self.llm_client:
            response = await self.llm_client.acomplete(prompt)
            return response.strip()

        raise ValueError(
            "No LLM client configured. Provide an 'llm', 'agent' or 'llm_client' to DatabaseTool."
        )

    async def _update_schema_knowledge_base(
        self,
        schema_metadata: SchemaMetadata
    ) -> None:
        """Update knowledge store with schema metadata for RAG."""
        if not self.knowledge_store:
            return

        documents = []

        for table in schema_metadata.tables:
            table_name = table.get('table_name')
            columns = table.get('columns', [])

            column_descriptions = [
                f"{col.get('column_name')} ({col.get('data_type', 'unknown')})"
                for col in columns
            ]

            doc_text = f"""
            Table: {schema_metadata.schema_name}.{table_name}
            Database: {schema_metadata.database_flavor.value}
            Columns: {', '.join(column_descriptions)}
            """

            documents.append({
                'content': doc_text,
                'metadata': {
                    'type': 'database_schema',
                    'schema': schema_metadata.schema_name,
                    'table': table_name,
                    'database_flavor': schema_metadata.database_flavor.value,
                }
            })

        await self.knowledge_store.add_documents(documents)

    def _convert_to_structured_output(
        self,
        raw_results: Any,
        schema: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Convert raw query results to structured output."""
        if not raw_results:
            return []

        field_mappings = schema.get('field_mappings', {})
        structured_results = []

        for row in raw_results:
            structured_row = {}

            if isinstance(row, dict):
                if field_mappings:
                    for target, source in field_mappings.items():
                        if source in row:
                            structured_row[target] = row[source]
                        elif target in row:
                            structured_row[target] = row[target]
                else:
                    structured_row = row
            else:
                fields = schema.get('fields', [])
                for i, field in enumerate(fields):
                    if i < len(row):
                        structured_row[field] = row[i]

            structured_results.append(structured_row)

        return structured_results
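    # Illustrative mapping: with schema {"field_mappings": {"customer": "cust_name"}},
    # a dict row {"cust_name": "Acme", "total": 7} becomes {"customer": "Acme"}
    # (unmapped keys are dropped); tuple rows are instead zipped positionally
    # against schema["fields"].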