ai-parrot 0.17.2__cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentui/.prettierrc +15 -0
- agentui/QUICKSTART.md +272 -0
- agentui/README.md +59 -0
- agentui/env.example +16 -0
- agentui/jsconfig.json +14 -0
- agentui/package-lock.json +4242 -0
- agentui/package.json +34 -0
- agentui/scripts/postinstall/apply-patches.mjs +260 -0
- agentui/src/app.css +61 -0
- agentui/src/app.d.ts +13 -0
- agentui/src/app.html +12 -0
- agentui/src/components/LoadingSpinner.svelte +64 -0
- agentui/src/components/ThemeSwitcher.svelte +159 -0
- agentui/src/components/index.js +4 -0
- agentui/src/lib/api/bots.ts +60 -0
- agentui/src/lib/api/chat.ts +22 -0
- agentui/src/lib/api/http.ts +25 -0
- agentui/src/lib/components/BotCard.svelte +33 -0
- agentui/src/lib/components/ChatBubble.svelte +63 -0
- agentui/src/lib/components/Toast.svelte +21 -0
- agentui/src/lib/config.ts +20 -0
- agentui/src/lib/stores/auth.svelte.ts +73 -0
- agentui/src/lib/stores/theme.svelte.js +64 -0
- agentui/src/lib/stores/toast.svelte.ts +31 -0
- agentui/src/lib/utils/conversation.ts +39 -0
- agentui/src/routes/+layout.svelte +20 -0
- agentui/src/routes/+page.svelte +232 -0
- agentui/src/routes/login/+page.svelte +200 -0
- agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
- agentui/src/routes/talk/[agentId]/+page.ts +7 -0
- agentui/static/README.md +1 -0
- agentui/svelte.config.js +11 -0
- agentui/tailwind.config.ts +53 -0
- agentui/tsconfig.json +3 -0
- agentui/vite.config.ts +10 -0
- ai_parrot-0.17.2.dist-info/METADATA +472 -0
- ai_parrot-0.17.2.dist-info/RECORD +535 -0
- ai_parrot-0.17.2.dist-info/WHEEL +6 -0
- ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
- ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
- ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
- crew-builder/.prettierrc +15 -0
- crew-builder/QUICKSTART.md +259 -0
- crew-builder/README.md +113 -0
- crew-builder/env.example +17 -0
- crew-builder/jsconfig.json +14 -0
- crew-builder/package-lock.json +4182 -0
- crew-builder/package.json +37 -0
- crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
- crew-builder/src/app.css +62 -0
- crew-builder/src/app.d.ts +13 -0
- crew-builder/src/app.html +12 -0
- crew-builder/src/components/LoadingSpinner.svelte +64 -0
- crew-builder/src/components/ThemeSwitcher.svelte +149 -0
- crew-builder/src/components/index.js +9 -0
- crew-builder/src/lib/api/bots.ts +60 -0
- crew-builder/src/lib/api/chat.ts +80 -0
- crew-builder/src/lib/api/client.ts +56 -0
- crew-builder/src/lib/api/crew/crew.ts +136 -0
- crew-builder/src/lib/api/index.ts +5 -0
- crew-builder/src/lib/api/o365/auth.ts +65 -0
- crew-builder/src/lib/auth/auth.ts +54 -0
- crew-builder/src/lib/components/AgentNode.svelte +43 -0
- crew-builder/src/lib/components/BotCard.svelte +33 -0
- crew-builder/src/lib/components/ChatBubble.svelte +67 -0
- crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
- crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
- crew-builder/src/lib/components/JsonViewer.svelte +24 -0
- crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
- crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
- crew-builder/src/lib/components/Toast.svelte +67 -0
- crew-builder/src/lib/components/Toolbar.svelte +157 -0
- crew-builder/src/lib/components/index.ts +10 -0
- crew-builder/src/lib/config.ts +8 -0
- crew-builder/src/lib/stores/auth.svelte.ts +228 -0
- crew-builder/src/lib/stores/crewStore.ts +369 -0
- crew-builder/src/lib/stores/theme.svelte.js +145 -0
- crew-builder/src/lib/stores/toast.svelte.ts +69 -0
- crew-builder/src/lib/utils/conversation.ts +39 -0
- crew-builder/src/lib/utils/markdown.ts +122 -0
- crew-builder/src/lib/utils/talkHistory.ts +47 -0
- crew-builder/src/routes/+layout.svelte +20 -0
- crew-builder/src/routes/+page.svelte +539 -0
- crew-builder/src/routes/agents/+page.svelte +247 -0
- crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
- crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
- crew-builder/src/routes/builder/+page.svelte +204 -0
- crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
- crew-builder/src/routes/crew/ask/+page.ts +1 -0
- crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
- crew-builder/src/routes/login/+page.svelte +197 -0
- crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
- crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
- crew-builder/static/README.md +1 -0
- crew-builder/svelte.config.js +11 -0
- crew-builder/tailwind.config.ts +53 -0
- crew-builder/tsconfig.json +3 -0
- crew-builder/vite.config.ts +10 -0
- mcp_servers/calculator_server.py +309 -0
- parrot/__init__.py +27 -0
- parrot/__pycache__/__init__.cpython-310.pyc +0 -0
- parrot/__pycache__/version.cpython-310.pyc +0 -0
- parrot/_version.py +34 -0
- parrot/a2a/__init__.py +48 -0
- parrot/a2a/client.py +658 -0
- parrot/a2a/discovery.py +89 -0
- parrot/a2a/mixin.py +257 -0
- parrot/a2a/models.py +376 -0
- parrot/a2a/server.py +770 -0
- parrot/agents/__init__.py +29 -0
- parrot/bots/__init__.py +12 -0
- parrot/bots/a2a_agent.py +19 -0
- parrot/bots/abstract.py +3139 -0
- parrot/bots/agent.py +1129 -0
- parrot/bots/basic.py +9 -0
- parrot/bots/chatbot.py +669 -0
- parrot/bots/data.py +1618 -0
- parrot/bots/database/__init__.py +5 -0
- parrot/bots/database/abstract.py +3071 -0
- parrot/bots/database/cache.py +286 -0
- parrot/bots/database/models.py +468 -0
- parrot/bots/database/prompts.py +154 -0
- parrot/bots/database/retries.py +98 -0
- parrot/bots/database/router.py +269 -0
- parrot/bots/database/sql.py +41 -0
- parrot/bots/db/__init__.py +6 -0
- parrot/bots/db/abstract.py +556 -0
- parrot/bots/db/bigquery.py +602 -0
- parrot/bots/db/cache.py +85 -0
- parrot/bots/db/documentdb.py +668 -0
- parrot/bots/db/elastic.py +1014 -0
- parrot/bots/db/influx.py +898 -0
- parrot/bots/db/mock.py +96 -0
- parrot/bots/db/multi.py +783 -0
- parrot/bots/db/prompts.py +185 -0
- parrot/bots/db/sql.py +1255 -0
- parrot/bots/db/tools.py +212 -0
- parrot/bots/document.py +680 -0
- parrot/bots/hrbot.py +15 -0
- parrot/bots/kb.py +170 -0
- parrot/bots/mcp.py +36 -0
- parrot/bots/orchestration/README.md +463 -0
- parrot/bots/orchestration/__init__.py +1 -0
- parrot/bots/orchestration/agent.py +155 -0
- parrot/bots/orchestration/crew.py +3330 -0
- parrot/bots/orchestration/fsm.py +1179 -0
- parrot/bots/orchestration/hr.py +434 -0
- parrot/bots/orchestration/storage/__init__.py +4 -0
- parrot/bots/orchestration/storage/memory.py +100 -0
- parrot/bots/orchestration/storage/mixin.py +119 -0
- parrot/bots/orchestration/verify.py +202 -0
- parrot/bots/product.py +204 -0
- parrot/bots/prompts/__init__.py +96 -0
- parrot/bots/prompts/agents.py +155 -0
- parrot/bots/prompts/data.py +216 -0
- parrot/bots/prompts/output_generation.py +8 -0
- parrot/bots/scraper/__init__.py +3 -0
- parrot/bots/scraper/models.py +122 -0
- parrot/bots/scraper/scraper.py +1173 -0
- parrot/bots/scraper/templates.py +115 -0
- parrot/bots/stores/__init__.py +5 -0
- parrot/bots/stores/local.py +172 -0
- parrot/bots/webdev.py +81 -0
- parrot/cli.py +17 -0
- parrot/clients/__init__.py +16 -0
- parrot/clients/base.py +1491 -0
- parrot/clients/claude.py +1191 -0
- parrot/clients/factory.py +129 -0
- parrot/clients/google.py +4567 -0
- parrot/clients/gpt.py +1975 -0
- parrot/clients/grok.py +432 -0
- parrot/clients/groq.py +986 -0
- parrot/clients/hf.py +582 -0
- parrot/clients/models.py +18 -0
- parrot/conf.py +395 -0
- parrot/embeddings/__init__.py +9 -0
- parrot/embeddings/base.py +157 -0
- parrot/embeddings/google.py +98 -0
- parrot/embeddings/huggingface.py +74 -0
- parrot/embeddings/openai.py +84 -0
- parrot/embeddings/processor.py +88 -0
- parrot/exceptions.c +13868 -0
- parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/exceptions.pxd +22 -0
- parrot/exceptions.pxi +15 -0
- parrot/exceptions.pyx +44 -0
- parrot/generators/__init__.py +29 -0
- parrot/generators/base.py +200 -0
- parrot/generators/html.py +293 -0
- parrot/generators/react.py +205 -0
- parrot/generators/streamlit.py +203 -0
- parrot/generators/template.py +105 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/agent.py +861 -0
- parrot/handlers/agents/__init__.py +1 -0
- parrot/handlers/agents/abstract.py +900 -0
- parrot/handlers/bots.py +338 -0
- parrot/handlers/chat.py +915 -0
- parrot/handlers/creation.sql +192 -0
- parrot/handlers/crew/ARCHITECTURE.md +362 -0
- parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
- parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
- parrot/handlers/crew/__init__.py +0 -0
- parrot/handlers/crew/handler.py +801 -0
- parrot/handlers/crew/models.py +229 -0
- parrot/handlers/crew/redis_persistence.py +523 -0
- parrot/handlers/jobs/__init__.py +10 -0
- parrot/handlers/jobs/job.py +384 -0
- parrot/handlers/jobs/mixin.py +627 -0
- parrot/handlers/jobs/models.py +115 -0
- parrot/handlers/jobs/worker.py +31 -0
- parrot/handlers/models.py +596 -0
- parrot/handlers/o365_auth.py +105 -0
- parrot/handlers/stream.py +337 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/aws.py +143 -0
- parrot/interfaces/credentials.py +113 -0
- parrot/interfaces/database.py +27 -0
- parrot/interfaces/google.py +1123 -0
- parrot/interfaces/hierarchy.py +1227 -0
- parrot/interfaces/http.py +651 -0
- parrot/interfaces/images/__init__.py +0 -0
- parrot/interfaces/images/plugins/__init__.py +24 -0
- parrot/interfaces/images/plugins/abstract.py +58 -0
- parrot/interfaces/images/plugins/analisys.py +148 -0
- parrot/interfaces/images/plugins/classify.py +150 -0
- parrot/interfaces/images/plugins/classifybase.py +182 -0
- parrot/interfaces/images/plugins/detect.py +150 -0
- parrot/interfaces/images/plugins/exif.py +1103 -0
- parrot/interfaces/images/plugins/hash.py +52 -0
- parrot/interfaces/images/plugins/vision.py +104 -0
- parrot/interfaces/images/plugins/yolo.py +66 -0
- parrot/interfaces/images/plugins/zerodetect.py +197 -0
- parrot/interfaces/o365.py +978 -0
- parrot/interfaces/onedrive.py +822 -0
- parrot/interfaces/sharepoint.py +1435 -0
- parrot/interfaces/soap.py +257 -0
- parrot/loaders/__init__.py +8 -0
- parrot/loaders/abstract.py +1131 -0
- parrot/loaders/audio.py +199 -0
- parrot/loaders/basepdf.py +53 -0
- parrot/loaders/basevideo.py +1568 -0
- parrot/loaders/csv.py +409 -0
- parrot/loaders/docx.py +116 -0
- parrot/loaders/epubloader.py +316 -0
- parrot/loaders/excel.py +199 -0
- parrot/loaders/factory.py +55 -0
- parrot/loaders/files/__init__.py +0 -0
- parrot/loaders/files/abstract.py +39 -0
- parrot/loaders/files/html.py +26 -0
- parrot/loaders/files/text.py +63 -0
- parrot/loaders/html.py +152 -0
- parrot/loaders/markdown.py +442 -0
- parrot/loaders/pdf.py +373 -0
- parrot/loaders/pdfmark.py +320 -0
- parrot/loaders/pdftables.py +506 -0
- parrot/loaders/ppt.py +476 -0
- parrot/loaders/qa.py +63 -0
- parrot/loaders/splitters/__init__.py +10 -0
- parrot/loaders/splitters/base.py +138 -0
- parrot/loaders/splitters/md.py +228 -0
- parrot/loaders/splitters/token.py +143 -0
- parrot/loaders/txt.py +26 -0
- parrot/loaders/video.py +89 -0
- parrot/loaders/videolocal.py +218 -0
- parrot/loaders/videounderstanding.py +377 -0
- parrot/loaders/vimeo.py +167 -0
- parrot/loaders/web.py +599 -0
- parrot/loaders/youtube.py +504 -0
- parrot/manager/__init__.py +5 -0
- parrot/manager/manager.py +1030 -0
- parrot/mcp/__init__.py +28 -0
- parrot/mcp/adapter.py +105 -0
- parrot/mcp/cli.py +174 -0
- parrot/mcp/client.py +119 -0
- parrot/mcp/config.py +75 -0
- parrot/mcp/integration.py +842 -0
- parrot/mcp/oauth.py +933 -0
- parrot/mcp/server.py +225 -0
- parrot/mcp/transports/__init__.py +3 -0
- parrot/mcp/transports/base.py +279 -0
- parrot/mcp/transports/grpc_session.py +163 -0
- parrot/mcp/transports/http.py +312 -0
- parrot/mcp/transports/mcp.proto +108 -0
- parrot/mcp/transports/quic.py +1082 -0
- parrot/mcp/transports/sse.py +330 -0
- parrot/mcp/transports/stdio.py +309 -0
- parrot/mcp/transports/unix.py +395 -0
- parrot/mcp/transports/websocket.py +547 -0
- parrot/memory/__init__.py +16 -0
- parrot/memory/abstract.py +209 -0
- parrot/memory/agent.py +32 -0
- parrot/memory/cache.py +175 -0
- parrot/memory/core.py +555 -0
- parrot/memory/file.py +153 -0
- parrot/memory/mem.py +131 -0
- parrot/memory/redis.py +613 -0
- parrot/models/__init__.py +46 -0
- parrot/models/basic.py +118 -0
- parrot/models/compliance.py +208 -0
- parrot/models/crew.py +395 -0
- parrot/models/detections.py +654 -0
- parrot/models/generation.py +85 -0
- parrot/models/google.py +223 -0
- parrot/models/groq.py +23 -0
- parrot/models/openai.py +30 -0
- parrot/models/outputs.py +285 -0
- parrot/models/responses.py +938 -0
- parrot/notifications/__init__.py +743 -0
- parrot/openapi/__init__.py +3 -0
- parrot/openapi/components.yaml +641 -0
- parrot/openapi/config.py +322 -0
- parrot/outputs/__init__.py +32 -0
- parrot/outputs/formats/__init__.py +108 -0
- parrot/outputs/formats/altair.py +359 -0
- parrot/outputs/formats/application.py +122 -0
- parrot/outputs/formats/base.py +351 -0
- parrot/outputs/formats/bokeh.py +356 -0
- parrot/outputs/formats/card.py +424 -0
- parrot/outputs/formats/chart.py +436 -0
- parrot/outputs/formats/d3.py +255 -0
- parrot/outputs/formats/echarts.py +310 -0
- parrot/outputs/formats/generators/__init__.py +0 -0
- parrot/outputs/formats/generators/abstract.py +61 -0
- parrot/outputs/formats/generators/panel.py +145 -0
- parrot/outputs/formats/generators/streamlit.py +86 -0
- parrot/outputs/formats/generators/terminal.py +63 -0
- parrot/outputs/formats/holoviews.py +310 -0
- parrot/outputs/formats/html.py +147 -0
- parrot/outputs/formats/jinja2.py +46 -0
- parrot/outputs/formats/json.py +87 -0
- parrot/outputs/formats/map.py +933 -0
- parrot/outputs/formats/markdown.py +172 -0
- parrot/outputs/formats/matplotlib.py +237 -0
- parrot/outputs/formats/mixins/__init__.py +0 -0
- parrot/outputs/formats/mixins/emaps.py +855 -0
- parrot/outputs/formats/plotly.py +341 -0
- parrot/outputs/formats/seaborn.py +310 -0
- parrot/outputs/formats/table.py +397 -0
- parrot/outputs/formats/template_report.py +138 -0
- parrot/outputs/formats/yaml.py +125 -0
- parrot/outputs/formatter.py +152 -0
- parrot/outputs/templates/__init__.py +95 -0
- parrot/pipelines/__init__.py +0 -0
- parrot/pipelines/abstract.py +210 -0
- parrot/pipelines/detector.py +124 -0
- parrot/pipelines/models.py +90 -0
- parrot/pipelines/planogram.py +3002 -0
- parrot/pipelines/table.sql +97 -0
- parrot/plugins/__init__.py +106 -0
- parrot/plugins/importer.py +80 -0
- parrot/py.typed +0 -0
- parrot/registry/__init__.py +18 -0
- parrot/registry/registry.py +594 -0
- parrot/scheduler/__init__.py +1189 -0
- parrot/scheduler/models.py +60 -0
- parrot/security/__init__.py +16 -0
- parrot/security/prompt_injection.py +268 -0
- parrot/security/security_events.sql +25 -0
- parrot/services/__init__.py +1 -0
- parrot/services/mcp/__init__.py +8 -0
- parrot/services/mcp/config.py +13 -0
- parrot/services/mcp/server.py +295 -0
- parrot/services/o365_remote_auth.py +235 -0
- parrot/stores/__init__.py +7 -0
- parrot/stores/abstract.py +352 -0
- parrot/stores/arango.py +1090 -0
- parrot/stores/bigquery.py +1377 -0
- parrot/stores/cache.py +106 -0
- parrot/stores/empty.py +10 -0
- parrot/stores/faiss_store.py +1157 -0
- parrot/stores/kb/__init__.py +9 -0
- parrot/stores/kb/abstract.py +68 -0
- parrot/stores/kb/cache.py +165 -0
- parrot/stores/kb/doc.py +325 -0
- parrot/stores/kb/hierarchy.py +346 -0
- parrot/stores/kb/local.py +457 -0
- parrot/stores/kb/prompt.py +28 -0
- parrot/stores/kb/redis.py +659 -0
- parrot/stores/kb/store.py +115 -0
- parrot/stores/kb/user.py +374 -0
- parrot/stores/models.py +59 -0
- parrot/stores/pgvector.py +3 -0
- parrot/stores/postgres.py +2853 -0
- parrot/stores/utils/__init__.py +0 -0
- parrot/stores/utils/chunking.py +197 -0
- parrot/telemetry/__init__.py +3 -0
- parrot/telemetry/mixin.py +111 -0
- parrot/template/__init__.py +3 -0
- parrot/template/engine.py +259 -0
- parrot/tools/__init__.py +23 -0
- parrot/tools/abstract.py +644 -0
- parrot/tools/agent.py +363 -0
- parrot/tools/arangodbsearch.py +537 -0
- parrot/tools/arxiv_tool.py +188 -0
- parrot/tools/calculator/__init__.py +3 -0
- parrot/tools/calculator/operations/__init__.py +38 -0
- parrot/tools/calculator/operations/calculus.py +80 -0
- parrot/tools/calculator/operations/statistics.py +76 -0
- parrot/tools/calculator/tool.py +150 -0
- parrot/tools/cloudwatch.py +988 -0
- parrot/tools/codeinterpreter/__init__.py +127 -0
- parrot/tools/codeinterpreter/executor.py +371 -0
- parrot/tools/codeinterpreter/internals.py +473 -0
- parrot/tools/codeinterpreter/models.py +643 -0
- parrot/tools/codeinterpreter/prompts.py +224 -0
- parrot/tools/codeinterpreter/tool.py +664 -0
- parrot/tools/company_info/__init__.py +6 -0
- parrot/tools/company_info/tool.py +1138 -0
- parrot/tools/correlationanalysis.py +437 -0
- parrot/tools/database/abstract.py +286 -0
- parrot/tools/database/bq.py +115 -0
- parrot/tools/database/cache.py +284 -0
- parrot/tools/database/models.py +95 -0
- parrot/tools/database/pg.py +343 -0
- parrot/tools/databasequery.py +1159 -0
- parrot/tools/db.py +1800 -0
- parrot/tools/ddgo.py +370 -0
- parrot/tools/decorators.py +271 -0
- parrot/tools/dftohtml.py +282 -0
- parrot/tools/document.py +549 -0
- parrot/tools/ecs.py +819 -0
- parrot/tools/edareport.py +368 -0
- parrot/tools/elasticsearch.py +1049 -0
- parrot/tools/employees.py +462 -0
- parrot/tools/epson/__init__.py +96 -0
- parrot/tools/excel.py +683 -0
- parrot/tools/file/__init__.py +13 -0
- parrot/tools/file/abstract.py +76 -0
- parrot/tools/file/gcs.py +378 -0
- parrot/tools/file/local.py +284 -0
- parrot/tools/file/s3.py +511 -0
- parrot/tools/file/tmp.py +309 -0
- parrot/tools/file/tool.py +501 -0
- parrot/tools/file_reader.py +129 -0
- parrot/tools/flowtask/__init__.py +19 -0
- parrot/tools/flowtask/tool.py +761 -0
- parrot/tools/gittoolkit.py +508 -0
- parrot/tools/google/__init__.py +18 -0
- parrot/tools/google/base.py +169 -0
- parrot/tools/google/tools.py +1251 -0
- parrot/tools/googlelocation.py +5 -0
- parrot/tools/googleroutes.py +5 -0
- parrot/tools/googlesearch.py +5 -0
- parrot/tools/googlesitesearch.py +5 -0
- parrot/tools/googlevoice.py +2 -0
- parrot/tools/gvoice.py +695 -0
- parrot/tools/ibisworld/README.md +225 -0
- parrot/tools/ibisworld/__init__.py +11 -0
- parrot/tools/ibisworld/tool.py +366 -0
- parrot/tools/jiratoolkit.py +1718 -0
- parrot/tools/manager.py +1098 -0
- parrot/tools/math.py +152 -0
- parrot/tools/metadata.py +476 -0
- parrot/tools/msteams.py +1621 -0
- parrot/tools/msword.py +635 -0
- parrot/tools/multidb.py +580 -0
- parrot/tools/multistoresearch.py +369 -0
- parrot/tools/networkninja.py +167 -0
- parrot/tools/nextstop/__init__.py +4 -0
- parrot/tools/nextstop/base.py +286 -0
- parrot/tools/nextstop/employee.py +733 -0
- parrot/tools/nextstop/store.py +462 -0
- parrot/tools/notification.py +435 -0
- parrot/tools/o365/__init__.py +42 -0
- parrot/tools/o365/base.py +295 -0
- parrot/tools/o365/bundle.py +522 -0
- parrot/tools/o365/events.py +554 -0
- parrot/tools/o365/mail.py +992 -0
- parrot/tools/o365/onedrive.py +497 -0
- parrot/tools/o365/sharepoint.py +641 -0
- parrot/tools/openapi_toolkit.py +904 -0
- parrot/tools/openweather.py +527 -0
- parrot/tools/pdfprint.py +1001 -0
- parrot/tools/powerbi.py +518 -0
- parrot/tools/powerpoint.py +1113 -0
- parrot/tools/pricestool.py +146 -0
- parrot/tools/products/__init__.py +246 -0
- parrot/tools/prophet_tool.py +171 -0
- parrot/tools/pythonpandas.py +630 -0
- parrot/tools/pythonrepl.py +910 -0
- parrot/tools/qsource.py +436 -0
- parrot/tools/querytoolkit.py +395 -0
- parrot/tools/quickeda.py +827 -0
- parrot/tools/resttool.py +553 -0
- parrot/tools/retail/__init__.py +0 -0
- parrot/tools/retail/bby.py +528 -0
- parrot/tools/sandboxtool.py +703 -0
- parrot/tools/sassie/__init__.py +352 -0
- parrot/tools/scraping/__init__.py +7 -0
- parrot/tools/scraping/docs/select.md +466 -0
- parrot/tools/scraping/documentation.md +1278 -0
- parrot/tools/scraping/driver.py +436 -0
- parrot/tools/scraping/models.py +576 -0
- parrot/tools/scraping/options.py +85 -0
- parrot/tools/scraping/orchestrator.py +517 -0
- parrot/tools/scraping/readme.md +740 -0
- parrot/tools/scraping/tool.py +3115 -0
- parrot/tools/seasonaldetection.py +642 -0
- parrot/tools/shell_tool/__init__.py +5 -0
- parrot/tools/shell_tool/actions.py +408 -0
- parrot/tools/shell_tool/engine.py +155 -0
- parrot/tools/shell_tool/models.py +322 -0
- parrot/tools/shell_tool/tool.py +442 -0
- parrot/tools/site_search.py +214 -0
- parrot/tools/textfile.py +418 -0
- parrot/tools/think.py +378 -0
- parrot/tools/toolkit.py +298 -0
- parrot/tools/webapp_tool.py +187 -0
- parrot/tools/whatif.py +1279 -0
- parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
- parrot/tools/workday/__init__.py +6 -0
- parrot/tools/workday/models.py +1389 -0
- parrot/tools/workday/tool.py +1293 -0
- parrot/tools/yfinance_tool.py +306 -0
- parrot/tools/zipcode.py +217 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/helpers.py +73 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.c +12078 -0
- parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/parsers/toml.pyx +21 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpp +20936 -0
- parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/types.pyx +213 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- parrot/yaml-rs/Cargo.lock +350 -0
- parrot/yaml-rs/Cargo.toml +19 -0
- parrot/yaml-rs/pyproject.toml +19 -0
- parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
- parrot/yaml-rs/src/lib.rs +222 -0
- requirements/docker-compose.yml +24 -0
- requirements/requirements-dev.txt +21 -0
parrot/bots/db/sql.py
ADDED
|
@@ -0,0 +1,1255 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Enhanced SQL Database Agent Implementation for AI-Parrot.
|
|
3
|
+
|
|
4
|
+
Concrete implementation of AbstractDBAgent for SQL databases with support for:
|
|
5
|
+
- PostgreSQL, MySQL, and SQL Server
|
|
6
|
+
- Dictionary and string credentials
|
|
7
|
+
- Dual DSN generation for SQLAlchemy and asyncdb
|
|
8
|
+
- DatabaseQueryTool integration for query validation and execution
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from typing import Dict, Any, List, Optional, Union
|
|
12
|
+
import re
|
|
13
|
+
from urllib.parse import urlparse, quote_plus
|
|
14
|
+
from datetime import datetime
|
|
15
|
+
import pandas as pd
|
|
16
|
+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
|
17
|
+
from sqlalchemy.orm import sessionmaker
|
|
18
|
+
from sqlalchemy import text
|
|
19
|
+
from .abstract import AbstractDBAgent
|
|
20
|
+
from .tools import DatabaseSchema, TableMetadata
|
|
21
|
+
from ...models import AIMessage
|
|
22
|
+
from ...tools.databasequery import DatabaseQueryTool
|
|
23
|
+
from ...tools import ToolResult
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class SQLAgent(AbstractDBAgent):
    """
    SQL Database Agent with dual DSN support and DatabaseQueryTool integration.

    Supports PostgreSQL, MySQL, and SQL Server with both dictionary and string credentials.
    """

    # Database flavor mappings for SQLAlchemy.
    # Maps user-facing flavor aliases to SQLAlchemy async dialect+driver strings
    # used when building the discovery DSN.
    SQLALCHEMY_DIALECT_MAPPING = {
        'postgresql': 'postgresql+asyncpg',
        'pg': 'postgresql+asyncpg',
        'postgres': 'postgresql+asyncpg',
        'mysql': 'mysql+aiomysql',
        'sqlserver': 'mssql+aioodbc',
        'mssql': 'mssql+aioodbc'
    }

    # Default ports for databases.
    # NOTE(review): the 'pg' alias accepted by SQLALCHEMY_DIALECT_MAPPING is
    # absent here, so lookups via DEFAULT_PORTS.get('pg', ...) fall back to the
    # caller-supplied default (5432 in the DSN builders) — confirm intended.
    DEFAULT_PORTS = {
        'postgresql': 5432,
        'postgres': 5432,
        'mysql': 3306,
        'sqlserver': 1433,
        'mssql': 1433
    }
|
+
def __init__(
    self,
    name: str = "SQLAgent",
    credentials: Optional[Union[str, Dict[str, Any]]] = None,
    database_flavor: str = "postgresql",
    schema_name: str = "public",
    max_sample_rows: int = 2,
    **kwargs
):
    """
    Initialize SQL Database Agent.

    Args:
        name: Agent name
        credentials: Connection credentials (dict or connection string)
        database_flavor: Database type (postgresql, mysql, sqlserver)
        schema_name: Target schema name
        max_sample_rows: Maximum rows to sample from each table

    Raises:
        ValueError: If database_flavor is not one of the supported flavors.
    """
    self.database_flavor = database_flavor.lower()
    self.max_sample_rows = max_sample_rows
    self.async_session_maker = None

    # DSN strings for different purposes
    self.discovery_dsn = None  # SQLAlchemy format for schema discovery
    self.dsn = None  # asyncdb format for DatabaseQueryTool
    self.credentials = None
    # Expose the raw dict before super().__init__ runs, in case the base
    # class inspects it during setup; _process_credentials() later sets it
    # again for dict credentials.
    self.connection_dict = credentials if isinstance(credentials, dict) else None

    # Validate database flavor before doing any further setup
    if self.database_flavor not in self.SQLALCHEMY_DIALECT_MAPPING:
        raise ValueError(
            f"Unsupported database flavor: {database_flavor}"
        )

    # Force low temperature to minimize hallucinations, unless the caller
    # explicitly chose one (setdefault is the idiomatic form of the
    # original kwargs.get() round-trip).
    kwargs.setdefault('temperature', 0.0)

    super().__init__(
        name=name,
        credentials=credentials,
        schema_name=schema_name,
        **kwargs
    )

    # Process credentials and generate DSNs
    self._process_credentials(credentials)

    # Add SQL-specific tools
    self._setup_sql_tools()
def _dsn_for_sqlalchemy(self, connection_string: str) -> str:
|
|
106
|
+
"""Adapt connection string for SQLAlchemy async drivers."""
|
|
107
|
+
parsed = urlparse(connection_string)
|
|
108
|
+
|
|
109
|
+
if parsed.scheme.startswith('postgresql') and '+asyncpg' not in parsed.scheme:
|
|
110
|
+
return connection_string.replace('postgresql://', 'postgresql+asyncpg://')
|
|
111
|
+
elif parsed.scheme.startswith('postgres') and '+asyncpg' not in parsed.scheme:
|
|
112
|
+
return connection_string.replace('postgres://', 'postgresql+asyncpg://')
|
|
113
|
+
elif parsed.scheme.startswith('mysql') and '+aiomysql' not in parsed.scheme:
|
|
114
|
+
return connection_string.replace('mysql://', 'mysql+aiomysql://')
|
|
115
|
+
elif parsed.scheme.startswith('mssql') and '+aioodbc' not in parsed.scheme:
|
|
116
|
+
return connection_string.replace('mssql://', 'mssql+aioodbc://')
|
|
117
|
+
|
|
118
|
+
return connection_string
|
|
119
|
+
|
|
120
|
+
def _dsn_for_asyncdb(self, connection_string: str) -> str:
|
|
121
|
+
"""Adapt connection string for asyncdb format."""
|
|
122
|
+
parsed = urlparse(connection_string)
|
|
123
|
+
|
|
124
|
+
# Check if already in asyncdb format:
|
|
125
|
+
if parsed.scheme in ['postgres', 'mysql', 'mssql']:
|
|
126
|
+
return connection_string
|
|
127
|
+
|
|
128
|
+
# Convert SQLAlchemy formats to asyncdb formats
|
|
129
|
+
if parsed.scheme.startswith('postgresql'):
|
|
130
|
+
return connection_string.replace(
|
|
131
|
+
'postgresql+asyncpg://', 'postgres://'
|
|
132
|
+
).replace('postgresql://', 'postgres://')
|
|
133
|
+
elif parsed.scheme.startswith('mysql'):
|
|
134
|
+
return connection_string.replace(
|
|
135
|
+
'mysql+aiomysql://', 'mysql://'
|
|
136
|
+
).replace('mysql://', 'mysql://')
|
|
137
|
+
elif parsed.scheme.startswith('mssql'):
|
|
138
|
+
return connection_string.replace(
|
|
139
|
+
'mssql+aioodbc://', 'mssql://'
|
|
140
|
+
).replace('mssql://', 'mssql://')
|
|
141
|
+
|
|
142
|
+
return connection_string
|
|
143
|
+
|
|
144
|
+
def _process_credentials(self, credentials: Union[str, Dict[str, Any]]) -> None:
|
|
145
|
+
"""
|
|
146
|
+
Process credentials and generate both discovery_dsn and dsn.
|
|
147
|
+
|
|
148
|
+
Args:
|
|
149
|
+
credentials: Either connection string or dictionary with connection params
|
|
150
|
+
"""
|
|
151
|
+
if isinstance(credentials, str):
|
|
152
|
+
# Connection string provided
|
|
153
|
+
self.connection_string = credentials
|
|
154
|
+
self.discovery_dsn = self._dsn_for_sqlalchemy(credentials)
|
|
155
|
+
self.dsn = self._dsn_for_asyncdb(credentials)
|
|
156
|
+
self.credentials = {}
|
|
157
|
+
elif isinstance(credentials, dict):
|
|
158
|
+
# Dictionary credentials provided
|
|
159
|
+
self.connection_dict = credentials
|
|
160
|
+
self.discovery_dsn = self._build_sqlalchemy_dsn_from_dict(credentials)
|
|
161
|
+
self.dsn = self._build_asyncdb_dsn_from_dict(credentials)
|
|
162
|
+
self.connection_string = self.discovery_dsn
|
|
163
|
+
self.credentials = credentials
|
|
164
|
+
else:
|
|
165
|
+
raise ValueError(
|
|
166
|
+
"Credentials must be either a connection string or dictionary"
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
def _build_sqlalchemy_dsn_from_dict(self, creds: Dict[str, Any]) -> str:
|
|
170
|
+
"""
|
|
171
|
+
Build SQLAlchemy DSN from credentials dictionary.
|
|
172
|
+
|
|
173
|
+
Args:
|
|
174
|
+
creds: Dictionary with keys like host, port, database, username, password
|
|
175
|
+
|
|
176
|
+
Returns:
|
|
177
|
+
SQLAlchemy-compatible connection string
|
|
178
|
+
"""
|
|
179
|
+
# Extract credentials with defaults
|
|
180
|
+
host = creds.get('host', 'localhost')
|
|
181
|
+
port = creds.get('port', self.DEFAULT_PORTS.get(self.database_flavor, 5432))
|
|
182
|
+
database = creds.get('database', creds.get('dbname', 'postgres'))
|
|
183
|
+
username = creds.get('username', creds.get('user', 'postgres'))
|
|
184
|
+
password = creds.get('password', creds.get('pwd', ''))
|
|
185
|
+
|
|
186
|
+
# URL encode password to handle special characters
|
|
187
|
+
encoded_password = quote_plus(str(password)) if password else ''
|
|
188
|
+
|
|
189
|
+
# Get SQLAlchemy dialect
|
|
190
|
+
dialect = self.SQLALCHEMY_DIALECT_MAPPING[self.database_flavor]
|
|
191
|
+
|
|
192
|
+
# Build connection string
|
|
193
|
+
if encoded_password:
|
|
194
|
+
dsn = f"{dialect}://{username}:{encoded_password}@{host}:{port}/{database}"
|
|
195
|
+
else:
|
|
196
|
+
dsn = f"{dialect}://{username}@{host}:{port}/{database}"
|
|
197
|
+
|
|
198
|
+
# Add any additional parameters
|
|
199
|
+
params = []
|
|
200
|
+
for key, value in creds.items():
|
|
201
|
+
if key not in ['host', 'port', 'database', 'dbname', 'username', 'user', 'password', 'pwd']:
|
|
202
|
+
params.append(f"{key}={value}")
|
|
203
|
+
|
|
204
|
+
if params:
|
|
205
|
+
dsn += "?" + "&".join(params)
|
|
206
|
+
|
|
207
|
+
return dsn
|
|
208
|
+
|
|
209
|
+
def _build_asyncdb_dsn_from_dict(self, creds: Dict[str, Any]) -> str:
|
|
210
|
+
"""
|
|
211
|
+
Build asyncdb DSN from credentials dictionary.
|
|
212
|
+
|
|
213
|
+
Args:
|
|
214
|
+
creds: Dictionary with connection parameters
|
|
215
|
+
|
|
216
|
+
Returns:
|
|
217
|
+
asyncdb-compatible connection string (postgres://...)
|
|
218
|
+
"""
|
|
219
|
+
# Extract credentials
|
|
220
|
+
host = creds.get('host', 'localhost')
|
|
221
|
+
port = creds.get('port', self.DEFAULT_PORTS.get(self.database_flavor, 5432))
|
|
222
|
+
database = creds.get('database', creds.get('dbname', 'postgres'))
|
|
223
|
+
username = creds.get('username', creds.get('user', 'postgres'))
|
|
224
|
+
password = creds.get('password', creds.get('pwd', ''))
|
|
225
|
+
|
|
226
|
+
# URL encode password
|
|
227
|
+
encoded_password = quote_plus(str(password)) if password else ''
|
|
228
|
+
|
|
229
|
+
# Get asyncdb scheme (postgres for PostgreSQL regardless of flavor name)
|
|
230
|
+
if self.database_flavor in ['postgresql', 'postgres']:
|
|
231
|
+
scheme = 'postgres'
|
|
232
|
+
elif self.database_flavor == 'mysql':
|
|
233
|
+
scheme = 'mysql'
|
|
234
|
+
elif self.database_flavor in ['sqlserver', 'mssql']:
|
|
235
|
+
scheme = 'mssql'
|
|
236
|
+
else:
|
|
237
|
+
scheme = 'postgres' # Default fallback
|
|
238
|
+
|
|
239
|
+
# Build DSN
|
|
240
|
+
if encoded_password:
|
|
241
|
+
dsn = f"{scheme}://{username}:{encoded_password}@{host}:{port}/{database}"
|
|
242
|
+
else:
|
|
243
|
+
dsn = f"{scheme}://{username}@{host}:{port}/{database}"
|
|
244
|
+
|
|
245
|
+
return dsn
|
|
246
|
+
|
|
247
|
+
def _setup_sql_tools(self):
|
|
248
|
+
"""Setup SQL-specific tools including DatabaseQueryTool."""
|
|
249
|
+
# The DatabaseQueryTool should already be registered in the parent class
|
|
250
|
+
# We just need to ensure it's configured properly
|
|
251
|
+
pass
|
|
252
|
+
|
|
253
|
+
async def connect_database(self) -> None:
    """Connect to the SQL database using a SQLAlchemy async engine.

    Creates ``self.engine`` and ``self.async_session_maker``, runs a
    ``SELECT 1`` round-trip to verify connectivity, then smoke-tests the
    DatabaseQueryTool.

    Raises:
        ValueError: If ``self.discovery_dsn`` is not set.
        Exception: Re-raises any engine/connection failure after logging.
    """
    if not self.discovery_dsn:
        raise ValueError("Discovery DSN is required")

    try:
        # Create async engine for schema discovery (separate from the
        # asyncdb DSN used by DatabaseQueryTool).
        self.engine = create_async_engine(
            self.discovery_dsn,
            echo=False,
            pool_pre_ping=True,  # validate pooled connections before use
            pool_recycle=3600    # recycle hourly to avoid stale sockets
        )

        # Create session maker bound to the discovery engine.
        self.async_session_maker = sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False
        )

        # Test connection with a cheap round-trip.
        async with self.engine.begin() as conn:
            await conn.execute(text("SELECT 1"))

        self.logger.info(
            f"Successfully connected to {self.database_flavor} database using SQLAlchemy"
        )

        # Test DatabaseQueryTool connection (asyncdb side); best-effort,
        # logs warnings but does not raise.
        await self._test_database_query_tool()

    except Exception as e:
        self.logger.error(f"Failed to connect to database: {e}")
        raise
|
|
288
|
+
|
|
289
|
+
async def _test_database_query_tool(self) -> None:
    """Smoke-test the DatabaseQueryTool with a trivial query.

    Looks up the 'database_query' tool and runs ``SELECT 1``. All
    failure modes (missing tool, failed query, raised exception) are
    logged as warnings only — this check never raises, so a broken tool
    does not abort connection setup.
    """
    try:
        # Get database query tool from registered tools
        db_tool = self.tool_manager.get_tool('database_query')
        if db_tool:
            # Test with a simple query; asyncdb expects the 'pg' driver
            # name for any PostgreSQL flavor alias.
            test_result = await db_tool.execute(
                driver='pg' if self.database_flavor in ['postgresql', 'postgres', 'pg'] else self.database_flavor,
                query="SELECT 1 as test_column LIMIT 1",
                dsn=self.dsn,
                credentials=self.credentials or None,
                output_format='native'
            )

            if test_result.status == "success":
                self.logger.debug(
                    "DatabaseQueryTool connection test successful"
                )
            else:
                self.logger.warning(
                    f"DatabaseQueryTool test failed: {test_result.error}"
                )
        else:
            self.logger.warning(
                "DatabaseQueryTool not found in registered tools"
            )

    except Exception as e:
        # Deliberately swallowed: the tool test is best-effort.
        self.logger.warning(
            f"DatabaseQueryTool test failed: {e}"
        )
|
|
321
|
+
|
|
322
|
+
async def extract_schema_metadata(self) -> DatabaseSchema:
    """Extract complete schema metadata from the SQL database.

    Connects lazily if no engine exists yet, reads the current database
    name, walks every base table in ``self.schema_name``, and packages
    the result. Views, functions and procedures are not discovered yet
    and are returned as empty lists.

    Returns:
        A DatabaseSchema with table metadata plus extraction bookkeeping
        (timestamp, counts, both DSNs).

    Raises:
        Exception: Re-raises any discovery failure after logging it.
    """
    if not self.engine:
        await self.connect_database()

    try:
        async with self.engine.begin() as conn:
            # Get database name via a flavor-specific query
            db_name_query = await self._get_database_name_query()
            result = await conn.execute(text(db_name_query))
            database_name = result.scalar()

            # Extract tables metadata (columns, keys, sample rows per table)
            tables = await self._extract_tables_metadata(conn)

            # Extract views metadata (simplified for now)
            views = []

            schema_metadata = DatabaseSchema(
                database_name=database_name or "unknown",
                database_type=self.database_flavor,
                tables=tables,
                views=views,
                functions=[],
                procedures=[],
                metadata={
                    "schema_name": self.schema_name,
                    "extraction_timestamp": datetime.now().isoformat(),
                    "total_tables": len(tables),
                    "total_views": len(views),
                    # Keep both DSNs for traceability/debugging.
                    "discovery_dsn": self.discovery_dsn,
                    "asyncdb_dsn": self.dsn
                }
            )

            self.logger.info(
                f"Extracted metadata for {len(tables)} tables"
            )

            return schema_metadata

    except Exception as e:
        self.logger.error(f"Failed to extract schema metadata: {e}")
        raise
|
|
366
|
+
|
|
367
|
+
async def _get_database_name_query(self) -> str:
|
|
368
|
+
"""Get database name query based on database flavor."""
|
|
369
|
+
if self.database_flavor in ['postgresql', 'postgres']:
|
|
370
|
+
return "SELECT current_database()"
|
|
371
|
+
elif self.database_flavor == 'mysql':
|
|
372
|
+
return "SELECT database()"
|
|
373
|
+
elif self.database_flavor in ['sqlserver', 'mssql']:
|
|
374
|
+
return "SELECT DB_NAME()"
|
|
375
|
+
else:
|
|
376
|
+
return "SELECT 'unknown' as database_name"
|
|
377
|
+
|
|
378
|
+
async def _extract_tables_metadata(self, conn) -> List[TableMetadata]:
    """Extract metadata for all base tables in ``self.schema_name``.

    Args:
        conn: An open SQLAlchemy async connection.

    Returns:
        A list of TableMetadata, one per table, in alphabetical order.
    """
    # information_schema.tables is spelled identically on PostgreSQL,
    # MySQL and SQL Server, so one query covers every supported flavor
    # (the original had three byte-identical per-flavor branches).
    table_query = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = :schema_name
        AND table_type = 'BASE TABLE'
        ORDER BY table_name
    """

    result = await conn.execute(
        text(table_query), {"schema_name": self.schema_name}
    )

    tables = []
    for row in result.fetchall():
        # Full per-table detail: columns, keys, sample rows.
        tables.append(await self._extract_single_table_metadata(conn, row[0]))

    return tables
|
|
420
|
+
|
|
421
|
+
async def _extract_single_table_metadata(self, conn, table_name: str) -> TableMetadata:
    """Assemble detailed metadata (columns, keys, sample rows) for one table."""
    # Keyword arguments evaluate left-to-right, so the lookups run in the
    # same order as before: columns, primary keys, foreign keys, samples.
    return TableMetadata(
        name=table_name,
        schema=self.schema_name,
        columns=await self._get_table_columns(conn, table_name),
        primary_keys=await self._get_primary_keys(conn, table_name),
        foreign_keys=await self._get_foreign_keys(conn, table_name),
        indexes=[],        # index discovery not implemented yet
        description=None,  # table comments not collected yet
        sample_data=await self._get_sample_data_via_tool(table_name),
    )
|
|
445
|
+
|
|
446
|
+
async def _get_table_columns(self, conn, table_name: str) -> List[Dict[str, Any]]:
    """Fetch column definitions for *table_name* from information_schema.

    Args:
        conn: An open SQLAlchemy async connection.
        table_name: Unqualified table name within ``self.schema_name``.

    Returns:
        One dict per column, in ordinal position order, with keys:
        name, type, nullable, default, max_length, precision, scale.
    """
    # The column query is identical across PostgreSQL, MySQL and SQL
    # Server (the original had three byte-identical per-flavor
    # branches), so a single statement suffices.
    query = """
        SELECT
            column_name,
            data_type,
            is_nullable,
            column_default,
            character_maximum_length,
            numeric_precision,
            numeric_scale
        FROM information_schema.columns
        WHERE table_schema = :schema_name
        AND table_name = :table_name
        ORDER BY ordinal_position
    """

    result = await conn.execute(text(query), {
        "schema_name": self.schema_name,
        "table_name": table_name
    })

    return [
        {
            "name": row[0],
            "type": row[1],
            # information_schema reports nullability as 'YES' / 'NO'
            "nullable": row[2] == "YES",
            "default": row[3],
            "max_length": row[4],
            "precision": row[5],
            "scale": row[6],
        }
        for row in result.fetchall()
    ]
|
|
512
|
+
|
|
513
|
+
async def _get_primary_keys(self, conn, table_name: str) -> List[str]:
    """Return the primary-key column names of *table_name*, in key order."""
    # PostgreSQL: resolve the PK constraint name via table_constraints.
    pg_sql = """
        SELECT column_name
        FROM information_schema.key_column_usage
        WHERE table_schema = :schema_name
        AND table_name = :table_name
        AND constraint_name IN (
            SELECT constraint_name
            FROM information_schema.table_constraints
            WHERE table_schema = :schema_name
            AND table_name = :table_name
            AND constraint_type = 'PRIMARY KEY'
        )
        ORDER BY ordinal_position
    """
    # MySQL names the PK constraint literally 'PRIMARY'.
    # NOTE(review): SQL Server uses generated PK names (PK__...), so this
    # branch may return nothing there — confirm.
    generic_sql = """
        SELECT column_name
        FROM information_schema.key_column_usage
        WHERE table_schema = :schema_name
        AND table_name = :table_name
        AND constraint_name = 'PRIMARY'
        ORDER BY ordinal_position
    """

    chosen = pg_sql if self.database_flavor in ['postgresql', 'postgres'] else generic_sql
    params = {"schema_name": self.schema_name, "table_name": table_name}
    cursor = await conn.execute(text(chosen), params)
    return [record[0] for record in cursor.fetchall()]
|
|
546
|
+
|
|
547
|
+
async def _get_foreign_keys(self, conn, table_name: str) -> List[Dict[str, Any]]:
    """Return foreign-key descriptors for *table_name*.

    Each entry maps the local column to the referenced
    schema / table / column it points at.
    """
    fk_sql = """
        SELECT
            kcu.column_name,
            ccu.table_schema AS referenced_table_schema,
            ccu.table_name AS referenced_table_name,
            ccu.column_name AS referenced_column_name
        FROM information_schema.key_column_usage kcu
        JOIN information_schema.constraint_column_usage ccu
            ON kcu.constraint_name = ccu.constraint_name
        WHERE kcu.table_schema = :schema_name
        AND kcu.table_name = :table_name
        AND kcu.constraint_name IN (
            SELECT constraint_name
            FROM information_schema.table_constraints
            WHERE table_schema = :schema_name
            AND table_name = :table_name
            AND constraint_type = 'FOREIGN KEY'
        )
    """

    cursor = await conn.execute(text(fk_sql), {
        "schema_name": self.schema_name,
        "table_name": table_name
    })

    return [
        {
            "column": rec[0],
            "referenced_table_schema": rec[1],
            "referenced_table": rec[2],
            "referenced_column": rec[3],
        }
        for rec in cursor.fetchall()
    ]
|
|
584
|
+
|
|
585
|
+
async def _get_sample_data_via_tool(self, table_name: str) -> List[Dict[str, Any]]:
    """Fetch up to ``self.max_sample_rows`` sample rows via DatabaseQueryTool.

    Best-effort: every failure mode (missing tool, failed query, raised
    exception) is logged as a warning and an empty list is returned.
    """
    try:
        # Get database query tool
        db_tool = self.tool_manager.get_tool('database_query')
        if not db_tool:
            self.logger.warning("DatabaseQueryTool not found")
            return []

        # Build sample query; schema-qualify unless it's the default
        # 'public' schema.
        # NOTE(review): double-quoted identifiers and LIMIT are
        # PostgreSQL-style — confirm MySQL/SQL Server flavors work here.
        full_table_name = f'"{self.schema_name}"."{table_name}"' if self.schema_name != 'public' else f'"{table_name}"'
        sample_query = f"SELECT * FROM {full_table_name} LIMIT {self.max_sample_rows}"

        # Execute query
        result = await db_tool.execute(
            driver='pg' if self.database_flavor in ['postgresql', 'postgres'] else self.database_flavor,
            query=sample_query,
            dsn=self.dsn,
            credentials=self.connection_dict,
            output_format='json'
        )
        if result.status == "success":
            return result.result
        else:
            self.logger.warning(f"Could not get sample data for {table_name}: {result.error}")
            return []

    except Exception as e:
        self.logger.warning(f"Error getting sample data for {table_name}: {e}")
        return []
|
|
615
|
+
|
|
616
|
+
async def generate_query(
    self,
    natural_language_query: str,
    target_tables: Optional[List[str]] = None,
    query_type: str = "SELECT"
) -> Dict[str, Any]:
    """Generate a SQL query from natural language and validate it.

    Args:
        natural_language_query: The user's request in plain language.
        target_tables: Optional explicit tables to focus schema lookup on.
        query_type: SQL statement kind to produce (default "SELECT").

    Returns:
        Dict with keys: query, query_type, tables_used,
        schema_context_used, validation, natural_language_input.

    Raises:
        Exception: Re-raises any generation/validation failure after
        logging it.
    """
    try:
        # Get schema context relevant to the request
        schema_context = await self._get_schema_context_for_query(
            natural_language_query, target_tables
        )

        # Build prompt for LLM
        prompt = self._build_query_generation_prompt(
            natural_language_query=natural_language_query,
            schema_context=schema_context,
            query_type=query_type,
            database_flavor=self.database_flavor
        )

        # Generate query using LLM directly (not the agent loop)
        response = await self._llm.ask(
            prompt=prompt,
            model=self._llm_model,
            temperature=0.0,  # Zero temperature for deterministic results
            use_tools=False,  # Explicitly disable tools to prevent recursion
            tools=[]
        )

        # Extract SQL query from response (strips ```sql fences if present)
        generated_query = self._extract_sql_from_response(str(response.output))

        # Validate query using DatabaseQueryTool with LIMIT 0
        validation_result = await self._validate_query_with_tool(generated_query)

        result = {
            "query": generated_query,
            "query_type": query_type,
            "tables_used": self._extract_tables_from_query(generated_query),
            "schema_context_used": len(schema_context),
            "validation": validation_result,
            "natural_language_input": natural_language_query
        }

        return result

    except Exception as e:
        self.logger.error(f"Failed to generate query: {e}")
        raise
|
|
666
|
+
|
|
667
|
+
async def _validate_query_with_tool(self, query: str) -> Dict[str, Any]:
|
|
668
|
+
"""Validate query using DatabaseQueryTool with LIMIT 0."""
|
|
669
|
+
try:
|
|
670
|
+
# Get database query tool
|
|
671
|
+
db_tool = None
|
|
672
|
+
for tool in self.tools:
|
|
673
|
+
if isinstance(tool, DatabaseQueryTool):
|
|
674
|
+
db_tool = tool
|
|
675
|
+
break
|
|
676
|
+
|
|
677
|
+
if not db_tool:
|
|
678
|
+
return {
|
|
679
|
+
"valid": False,
|
|
680
|
+
"error": "DatabaseQueryTool not available",
|
|
681
|
+
"method": "tool_validation"
|
|
682
|
+
}
|
|
683
|
+
|
|
684
|
+
# Modify query to add LIMIT 0 for validation (no data returned)
|
|
685
|
+
if query.strip().upper().startswith('SELECT'):
|
|
686
|
+
validation_query = f"SELECT * FROM ({query.rstrip(';')}) AS validation_subquery LIMIT 0"
|
|
687
|
+
else:
|
|
688
|
+
# For non-SELECT queries, we can't easily validate without risk
|
|
689
|
+
validation_query = query
|
|
690
|
+
|
|
691
|
+
# Execute validation query
|
|
692
|
+
result = await db_tool.execute(
|
|
693
|
+
driver='pg' if self.database_flavor in ['postgresql', 'postgres'] else self.database_flavor,
|
|
694
|
+
query=validation_query,
|
|
695
|
+
dsn=self.dsn,
|
|
696
|
+
credentials=self.connection_dict,
|
|
697
|
+
output_format='native'
|
|
698
|
+
)
|
|
699
|
+
|
|
700
|
+
return {
|
|
701
|
+
"valid": result.status == "success",
|
|
702
|
+
"error": result.error if result.status == "error" else None,
|
|
703
|
+
"method": "database_query_tool",
|
|
704
|
+
"validation_query": validation_query
|
|
705
|
+
}
|
|
706
|
+
|
|
707
|
+
except Exception as e:
|
|
708
|
+
return {
|
|
709
|
+
"valid": False,
|
|
710
|
+
"error": str(e),
|
|
711
|
+
"method": "tool_validation"
|
|
712
|
+
}
|
|
713
|
+
|
|
714
|
+
async def explain_query(self, query: str) -> str:
    """
    Explain a database query (e.g. EXPLAIN ANALYZE).

    Args:
        query: The SQL query to explain

    Returns:
        The execution plan as a string; on failure, a human-readable
        error message string (this method never raises).
    """
    try:
        # Construct EXPLAIN query based on flavor
        if self.database_flavor in ['postgresql', 'postgres', 'pg']:
            # Use JSON format for better parsing if needed, and ANALYZE for actual execution stats
            # NOTE(review): ANALYZE actually runs the statement —
            # confirm this is acceptable for non-SELECT input.
            explain_query = f"EXPLAIN (FORMAT JSON, ANALYZE) {query}"
        elif self.database_flavor == 'mysql':
            explain_query = f"EXPLAIN ANALYZE {query}"
        else:
            explain_query = f"EXPLAIN {query}"

        # Execute the explain query
        # We use execute_query but need to handle the result format
        result = await self.execute_query(explain_query, limit=0)  # limit=0 is ignored for EXPLAIN usually

        if result["success"]:
            # Format the result; 'data' is the DataFrame execute_query returns
            data = result["data"]
            if self.database_flavor in ['postgresql', 'postgres', 'pg']:
                # Postgres JSON output usually comes as a single cell with lists
                try:
                    # It might be a list of dicts in the first column
                    plan = data.iloc[0, 0]
                    if isinstance(plan, list) or isinstance(plan, dict):
                        return json.dumps(plan, indent=2)
                    return str(plan)
                except Exception:
                    # Fall back to a plain-text dump of the frame.
                    return data.to_string()
            else:
                return data.to_string()
        else:
            return f"Failed to explain query: {result['error']}"

    except Exception as e:
        self.logger.error(f"Error explaining query: {e}")
        return f"Error explaining query: {str(e)}"
|
|
759
|
+
|
|
760
|
+
async def execute_query(self, query: str, limit: int = 200) -> Dict[str, Any]:
    """Execute a SQL query via DatabaseQueryTool and return a result dict.

    Args:
        query: SQL statement to run.
        limit: Row cap appended to SELECT statements that carry no LIMIT.

    Returns:
        On success: {'success': True, 'data': DataFrame, 'columns',
        'row_count', 'query', 'tool_used', 'raw_result'}.
        On failure: {'success': False, 'error', 'query', 'tool_used',
        'raw_result'} (tool_used/raw_result absent only when the tool
        itself is missing). Never raises.
    """
    try:
        # Get database query tool (the original also carried a dead
        # "db_tool = None" reassignment and an unused "result = None").
        db_tool = self.tool_manager.get_tool('database_query')
        if not db_tool:
            return {
                "success": False,
                "error": "DatabaseQueryTool not available",
                "query": query
            }

        # Cap result size for unbounded SELECTs; other statements run as-is.
        execution_query = query
        if query.strip().upper().startswith('SELECT') and 'LIMIT' not in query.upper():
            execution_query = f"{query.rstrip(';')} LIMIT {limit}"

        # Execute query (returns a ToolResult wrapping a pandas DataFrame)
        result = await db_tool.execute(
            driver='pg' if self.database_flavor in ['postgresql', 'postgres'] else self.database_flavor,
            query=execution_query,
            dsn=self.dsn,
            credentials=self.connection_dict,
            output_format='pandas'
        )

        if result.status == "success":
            data = result.result
            columns = data.columns.tolist() if not data.empty else []
            row_count = len(data) if not data.empty else 0
            return {
                "success": True,
                "data": data,
                "columns": columns,
                "row_count": row_count,
                "query": execution_query,
                "tool_used": "DatabaseQueryTool",
                "raw_result": result
            }

        return {
            "success": False,
            "error": result.error,
            "query": execution_query,
            "tool_used": "DatabaseQueryTool",
            "raw_result": result
        }

    except Exception as e:
        self.logger.error(f"Query execution failed: {e}")
        return {
            "success": False,
            "error": str(e),
            "query": query,
            "tool_used": "DatabaseQueryTool",
            "raw_result": None
        }
|
|
821
|
+
|
|
822
|
+
async def _get_schema_context_for_query(
|
|
823
|
+
self,
|
|
824
|
+
natural_language_query: str,
|
|
825
|
+
target_tables: Optional[List[str]] = None
|
|
826
|
+
) -> List[Dict[str, Any]]:
|
|
827
|
+
"""Get relevant schema context for query generation."""
|
|
828
|
+
if target_tables:
|
|
829
|
+
context = []
|
|
830
|
+
for table_name in target_tables:
|
|
831
|
+
table_info = await self.search_schema(
|
|
832
|
+
search_term=table_name,
|
|
833
|
+
search_type="tables",
|
|
834
|
+
limit=1
|
|
835
|
+
)
|
|
836
|
+
if table_info:
|
|
837
|
+
context.extend(table_info)
|
|
838
|
+
return context
|
|
839
|
+
else:
|
|
840
|
+
return await self.search_schema(
|
|
841
|
+
search_term=natural_language_query,
|
|
842
|
+
search_type="all",
|
|
843
|
+
limit=5
|
|
844
|
+
)
|
|
845
|
+
|
|
846
|
+
def _build_query_generation_prompt(
    self,
    natural_language_query: str,
    schema_context: List[Dict[str, Any]],
    query_type: str,
    database_flavor: str
) -> str:
    """Build the LLM prompt for SQL query generation.

    Args:
        natural_language_query: The user's request in plain language.
        schema_context: Ranked schema snippets; only the first 3 are used.
        query_type: SQL statement kind to generate (e.g. "SELECT").
        database_flavor: Target dialect used to steer syntax choices.

    Returns:
        The fully assembled prompt string, ending with "SQL Query:".
    """
    prompt = f"""
You are an expert SQL developer working with a {database_flavor} database.
Generate a clean, efficient {query_type} SQL query based on the natural language request and schema information.

Natural Language Request: {natural_language_query}

Available Schema Information:
"""

    # Cap the context at three snippets to keep the prompt compact.
    for i, context in enumerate(schema_context[:3], 1):
        prompt += f"\n{i}. {context.get('content', '')}\n"

    prompt += f"""
Requirements:
1. Generate valid {database_flavor} SQL with clean formatting
2. Use appropriate {database_flavor} syntax and functions
3. Use simple column names unless JOINs require qualification
4. Use table aliases for readability in JOINs
5. Only use double quotes for identifiers with special characters
6. Include appropriate WHERE clauses and filters
7. Optimize for performance and readability
8. Return ONLY the SQL query without explanations or formatting

Query Type: {query_type}
Database: {database_flavor}

SQL Query:"""

    return prompt
|
|
883
|
+
|
|
884
|
+
def _extract_sql_from_response(self, response_text: str) -> str:
|
|
885
|
+
"""Extract SQL query from LLM response."""
|
|
886
|
+
# Remove markdown code blocks if present
|
|
887
|
+
if "```sql" in response_text:
|
|
888
|
+
lines = response_text.split('\n')
|
|
889
|
+
sql_lines = []
|
|
890
|
+
in_sql_block = False
|
|
891
|
+
|
|
892
|
+
for line in lines:
|
|
893
|
+
if line.strip().startswith("```sql"):
|
|
894
|
+
in_sql_block = True
|
|
895
|
+
continue
|
|
896
|
+
elif line.strip() == "```" and in_sql_block:
|
|
897
|
+
break
|
|
898
|
+
elif in_sql_block:
|
|
899
|
+
sql_lines.append(line)
|
|
900
|
+
|
|
901
|
+
return '\n'.join(sql_lines).strip()
|
|
902
|
+
else:
|
|
903
|
+
return response_text.strip()
|
|
904
|
+
|
|
905
|
+
async def ask(
    self,
    question: str = None,
    user_context: str = "",
    context: str = "",
    return_results: bool = True,  # New parameter to control query execution
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    use_conversation_history: bool = True,
    **kwargs
) -> AIMessage:
    """
    Enhanced ask method that can automatically execute generated SQL queries.

    Args:
        question: The user's question about the database
        user_context: User-specific context for database interaction
        context: Additional context about data location, schema guidance
        return_results: If True, automatically execute generated SQL queries and return data
        session_id: Session identifier for conversation history
        user_id: User identifier
        use_conversation_history: Whether to use conversation history
        **kwargs: Additional arguments for LLM

    Returns:
        AIMessage: The response from the LLM, potentially enhanced with
        query results. On auto-execution, ``output`` is replaced with
        the result data and the original text moves to ``response``.
        Auto-execution failures are logged, never raised.
    """
    # Backwards compatibility: accept 'prompt' as an alias for question
    if question is None:
        question = kwargs.get('prompt')

    # First, get the standard response from the parent method
    response = await super().ask(
        question=question,
        user_context=user_context,
        context=context,
        session_id=session_id,
        user_id=user_id,
        use_conversation_history=use_conversation_history,
        **kwargs
    )

    # If return_results is False, return the response as-is
    if not return_results:
        return response

    # Try to extract and execute SQL queries from the response
    try:
        response_text = str(response.output) if response.output else ""

        # Extract SQL queries from the response
        sql_queries = self._extract_queries(response_text)

        if sql_queries:
            # Execute the first/main SQL query
            main_query = sql_queries[0]
            self.logger.debug(
                f"Auto-executing extracted query: {main_query[:100]}..."
            )

            # Execute the query
            result = await self.execute_query(
                query=main_query
            )
            # Preserve original response text
            response.response = response_text
            # output now carries the result data (a DataFrame on success)
            response.output = result.get('data', None)
            response.raw_response = result  # Preserve raw execution result

            # Add execution metadata if response has metadata attribute
            if hasattr(response, 'metadata') and response.metadata:
                response.metadata.update({
                    'auto_executed_query': True,
                    'executed_query': main_query,
                    # Fixed: execute_query() reports 'success', not
                    # 'status' — the original check was always False.
                    'execution_success': result.get('success', False),
                    'row_count': result.get('row_count', 0),
                    'columns': result.get('columns', []),
                    'error': result.get('error', None)
                })

    except Exception as e:
        self.logger.warning(
            f"Failed to auto-execute query: {e}"
        )
        # Don't fail the entire request, just log the warning
        # The user still gets the explanation even if execution fails

    return response
|
|
994
|
+
|
|
995
|
+
async def search_schema(
    self,
    search_term: str,
    search_type: str = "all",
    limit: int = 10
) -> List[Dict[str, Any]]:
    """
    Search the database schema using SQL queries against information_schema.

    Args:
        search_term: Term to search for (supports LIKE patterns implicitly)
        search_type: Type of search ('tables', 'columns', 'all')
        limit: Maximum number of results

    Returns:
        List of matching schema objects; empty list on any error
        (this method never raises).
    """
    results = []

    # Check cache first
    if self.cache:
        cached_results = await self.cache.get(search_term, search_type, limit)
        if cached_results is not None:
            self.logger.info(f"Schema search cache hit for term: {search_term}")
            return cached_results

    # Substring match on both sides of the term.
    term_pattern = f"%{search_term}%"

    try:
        # Determine logic based on search_type
        search_tables = search_type in ["all", "tables"]
        search_columns = search_type in ["all", "columns"]

        # --- Search Tables ---
        if search_tables:
            if self.database_flavor in ['postgresql', 'postgres', 'pg']:
                # Support schema.table search (ILIKE = case-insensitive);
                # system schemas are excluded.
                query = """
                SELECT table_schema, table_name, 'TABLE' as type
                FROM information_schema.tables
                WHERE (table_name ILIKE :term
                OR table_schema || '.' || table_name ILIKE :term)
                AND table_schema NOT IN ('information_schema', 'pg_catalog')
                AND table_type = 'BASE TABLE'
                LIMIT :limit
                """
            elif self.database_flavor == 'mysql':
                # Restrict to the current database.
                query = """
                SELECT table_schema, table_name, 'TABLE' as type
                FROM information_schema.tables
                WHERE (table_name LIKE :term
                OR CONCAT(table_schema, '.', table_name) LIKE :term)
                AND table_schema = DATABASE()
                AND table_type = 'BASE TABLE'
                LIMIT :limit
                """
            else:  # Generic/SQL Server
                # NOTE(review): LIMIT is not T-SQL syntax — confirm the
                # generic branch actually works on SQL Server.
                query = """
                SELECT table_schema, table_name, 'TABLE' as type
                FROM information_schema.tables
                WHERE table_name LIKE :term
                LIMIT :limit
                """

            if self.engine:
                async with self.engine.connect() as conn:
                    result_proxy = await conn.execute(text(query), {"term": term_pattern, "limit": limit})
                    rows = result_proxy.fetchall()
                    for row in rows:
                        results.append({
                            "type": "table",
                            "name": row[1],
                            "schema": row[0],
                            "description": f"Table: {row[0]}.{row[1]}"
                        })

        # --- Search Columns ---
        # Only fill whatever quota the table search left unused.
        if search_columns and len(results) < limit:
            current_limit = limit - len(results)
            if self.database_flavor in ['postgresql', 'postgres', 'pg']:
                query = """
                SELECT table_schema, table_name, column_name, data_type
                FROM information_schema.columns
                WHERE column_name ILIKE :term
                AND table_schema NOT IN ('information_schema', 'pg_catalog')
                LIMIT :limit
                """
            elif self.database_flavor == 'mysql':
                query = """
                SELECT table_schema, table_name, column_name, data_type
                FROM information_schema.columns
                WHERE column_name LIKE :term
                AND table_schema = DATABASE()
                LIMIT :limit
                """
            else:  # Generic/SQL Server
                query = """
                SELECT table_schema, table_name, column_name, data_type
                FROM information_schema.columns
                WHERE column_name LIKE :term
                LIMIT :limit
                """

            if self.engine:
                async with self.engine.connect() as conn:
                    result_proxy = await conn.execute(text(query), {"term": term_pattern, "limit": current_limit})
                    rows = result_proxy.fetchall()
                    for row in rows:
                        results.append({
                            "type": "column",
                            "table": row[1],
                            "schema": row[0],
                            "name": row[2],
                            "description": f"Column: {row[2]} (Type: {row[3]}) in {row[0]}.{row[1]}",
                            "metadata": f"Type: {row[3]}"
                        })

        # Cache the results ONLY if we found something
        # This prevents caching False Negatives (empty results) which might be due to transient issues or bad queries
        if self.cache and results:
            await self.cache.set(search_term, search_type, limit, results)

        return results

    except Exception as e:
        self.logger.error(f"Error in SQL-based search_schema: {e}")
        return []
|
|
1122
|
+
|
|
1123
|
+
def _extract_queries(self, response_text: str) -> List[str]:
|
|
1124
|
+
"""
|
|
1125
|
+
Extract SQL queries from LLM response text.
|
|
1126
|
+
|
|
1127
|
+
Args:
|
|
1128
|
+
response_text: The full response text from the LLM
|
|
1129
|
+
|
|
1130
|
+
Returns:
|
|
1131
|
+
List of extracted SQL queries
|
|
1132
|
+
"""
|
|
1133
|
+
queries = []
|
|
1134
|
+
|
|
1135
|
+
# Method 1: Extract from markdown code blocks
|
|
1136
|
+
sql_pattern = r'```sql\n(.*?)\n```'
|
|
1137
|
+
matches = re.findall(sql_pattern, response_text, re.DOTALL | re.IGNORECASE)
|
|
1138
|
+
|
|
1139
|
+
for match in matches:
|
|
1140
|
+
cleaned_query = match.strip()
|
|
1141
|
+
if cleaned_query and not cleaned_query.lower().startswith('--'):
|
|
1142
|
+
queries.append(cleaned_query)
|
|
1143
|
+
|
|
1144
|
+
# Method 2: If no markdown blocks, look for SQL-like patterns
|
|
1145
|
+
# CAUTION: This fallback generates false positives for explanations.
|
|
1146
|
+
# We will disable aggressive line scanning and only support markdown blocks or single-line exact queries.
|
|
1147
|
+
if not queries:
|
|
1148
|
+
cleaned_text = response_text.strip()
|
|
1149
|
+
# If the whole text looks like a query (starts with keyword, ends with ;)
|
|
1150
|
+
if re.match(r'^(SELECT|WITH|SHOW|DESCRIBE|EXPLAIN)\b.*?;$', cleaned_text, re.IGNORECASE | re.DOTALL):
|
|
1151
|
+
queries.append(cleaned_text)
|
|
1152
|
+
|
|
1153
|
+
# Clean up queries
|
|
1154
|
+
cleaned_queries = []
|
|
1155
|
+
for query in queries:
|
|
1156
|
+
# Remove common prefixes/suffixes
|
|
1157
|
+
query = re.sub(r'^```sql\s*', '', query, flags=re.IGNORECASE)
|
|
1158
|
+
query = re.sub(r'\s*```$', '', query)
|
|
1159
|
+
query = query.strip()
|
|
1160
|
+
|
|
1161
|
+
# Basic validation - should contain SELECT, WITH, etc.
|
|
1162
|
+
if re.search(r'\b(SELECT|WITH|SHOW|DESCRIBE|EXPLAIN)\b', query, re.IGNORECASE):
|
|
1163
|
+
cleaned_queries.append(query)
|
|
1164
|
+
|
|
1165
|
+
return cleaned_queries
|
|
1166
|
+
|
|
1167
|
+
def _extract_tables_from_query(self, query: str) -> List[str]:
|
|
1168
|
+
"""Extract table names from SQL query."""
|
|
1169
|
+
pattern = r'(?:FROM|JOIN)\s+(?:[\w\.]*\.)?(\w+)'
|
|
1170
|
+
matches = re.findall(pattern, query.upper())
|
|
1171
|
+
return list(set(matches))
|
|
1172
|
+
|
|
1173
|
+
async def cleanup(self) -> None:
    """Release the SQLAlchemy engine and run base-class cleanup.

    Disposes the async engine's connection pool (if one was created) and
    clears the reference so a later call never touches a disposed engine.
    """
    if self.engine:
        await self.engine.dispose()
        # Drop the reference: makes repeated cleanup() calls idempotent
        # and prevents accidental reuse of a disposed engine.
        self.engine = None
    await super().cleanup()
|
|
1178
|
+
|
|
1179
|
+
|
|
1180
|
+
|
|
1181
|
+
|
|
1182
|
+
# Factory function for creating enhanced SQL agents
|
|
1183
|
+
def create_sql_agent(
    database_flavor: str,
    credentials: Union[str, Dict[str, Any]],
    schema_name: str = None,
    **kwargs
) -> SQLAgent:
    """
    Factory function to create SQL database agents.

    Args:
        database_flavor: Database type ('postgresql', 'mysql', 'sqlserver')
        credentials: Connection credentials (string or dict)
        schema_name: Target schema name
        **kwargs: Additional arguments

    Returns:
        Configured SQLAgent instance
    """
    # Pick a flavor-appropriate default schema when none was supplied;
    # unrecognised flavors fall back to 'public'.
    if schema_name is None:
        default_schemas = {
            'postgresql': 'public',
            'postgres': 'public',
            'mysql': 'mysql',
            'sqlserver': 'dbo',
            'mssql': 'dbo',
        }
        schema_name = default_schemas.get(database_flavor.lower(), 'public')

    return SQLAgent(
        database_flavor=database_flavor,
        credentials=credentials,
        schema_name=schema_name,
        **kwargs
    )
|
|
1218
|
+
|
|
1219
|
+
|
|
1220
|
+
# Example usage
|
|
1221
|
+
"""
|
|
1222
|
+
# Dictionary credentials example
|
|
1223
|
+
pg_creds = {
|
|
1224
|
+
'host': 'localhost',
|
|
1225
|
+
'port': 5432,
|
|
1226
|
+
'database': 'sales_db',
|
|
1227
|
+
'username': 'user',
|
|
1228
|
+
'password': 'password'
|
|
1229
|
+
}
|
|
1230
|
+
|
|
1231
|
+
pg_agent = create_sql_agent(
|
|
1232
|
+
database_flavor='postgresql',
|
|
1233
|
+
credentials=pg_creds,
|
|
1234
|
+
schema_name='public'
|
|
1235
|
+
)
|
|
1236
|
+
|
|
1237
|
+
# Connection string example
|
|
1238
|
+
mysql_agent = create_sql_agent(
|
|
1239
|
+
database_flavor='mysql',
|
|
1240
|
+
credentials='mysql://user:pass@localhost/dbname'
|
|
1241
|
+
)
|
|
1242
|
+
|
|
1243
|
+
# Usage
|
|
1244
|
+
await pg_agent.initialize_schema()
|
|
1245
|
+
|
|
1246
|
+
# Generate and execute query
|
|
1247
|
+
query_result = await pg_agent.generate_query(
|
|
1248
|
+
"Show me all customers from the East region with their order totals"
|
|
1249
|
+
)
|
|
1250
|
+
|
|
1251
|
+
execution_result = await pg_agent.execute_query(query_result['query'])
|
|
1252
|
+
print(f"Query: {execution_result['query']}")
|
|
1253
|
+
print(f"Data: {execution_result['data']}")
|
|
1254
|
+
"""
|
|
1255
|
+
|