ai-parrot 0.17.2 (cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentui/.prettierrc +15 -0
- agentui/QUICKSTART.md +272 -0
- agentui/README.md +59 -0
- agentui/env.example +16 -0
- agentui/jsconfig.json +14 -0
- agentui/package-lock.json +4242 -0
- agentui/package.json +34 -0
- agentui/scripts/postinstall/apply-patches.mjs +260 -0
- agentui/src/app.css +61 -0
- agentui/src/app.d.ts +13 -0
- agentui/src/app.html +12 -0
- agentui/src/components/LoadingSpinner.svelte +64 -0
- agentui/src/components/ThemeSwitcher.svelte +159 -0
- agentui/src/components/index.js +4 -0
- agentui/src/lib/api/bots.ts +60 -0
- agentui/src/lib/api/chat.ts +22 -0
- agentui/src/lib/api/http.ts +25 -0
- agentui/src/lib/components/BotCard.svelte +33 -0
- agentui/src/lib/components/ChatBubble.svelte +63 -0
- agentui/src/lib/components/Toast.svelte +21 -0
- agentui/src/lib/config.ts +20 -0
- agentui/src/lib/stores/auth.svelte.ts +73 -0
- agentui/src/lib/stores/theme.svelte.js +64 -0
- agentui/src/lib/stores/toast.svelte.ts +31 -0
- agentui/src/lib/utils/conversation.ts +39 -0
- agentui/src/routes/+layout.svelte +20 -0
- agentui/src/routes/+page.svelte +232 -0
- agentui/src/routes/login/+page.svelte +200 -0
- agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
- agentui/src/routes/talk/[agentId]/+page.ts +7 -0
- agentui/static/README.md +1 -0
- agentui/svelte.config.js +11 -0
- agentui/tailwind.config.ts +53 -0
- agentui/tsconfig.json +3 -0
- agentui/vite.config.ts +10 -0
- ai_parrot-0.17.2.dist-info/METADATA +472 -0
- ai_parrot-0.17.2.dist-info/RECORD +535 -0
- ai_parrot-0.17.2.dist-info/WHEEL +6 -0
- ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
- ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
- ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
- crew-builder/.prettierrc +15 -0
- crew-builder/QUICKSTART.md +259 -0
- crew-builder/README.md +113 -0
- crew-builder/env.example +17 -0
- crew-builder/jsconfig.json +14 -0
- crew-builder/package-lock.json +4182 -0
- crew-builder/package.json +37 -0
- crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
- crew-builder/src/app.css +62 -0
- crew-builder/src/app.d.ts +13 -0
- crew-builder/src/app.html +12 -0
- crew-builder/src/components/LoadingSpinner.svelte +64 -0
- crew-builder/src/components/ThemeSwitcher.svelte +149 -0
- crew-builder/src/components/index.js +9 -0
- crew-builder/src/lib/api/bots.ts +60 -0
- crew-builder/src/lib/api/chat.ts +80 -0
- crew-builder/src/lib/api/client.ts +56 -0
- crew-builder/src/lib/api/crew/crew.ts +136 -0
- crew-builder/src/lib/api/index.ts +5 -0
- crew-builder/src/lib/api/o365/auth.ts +65 -0
- crew-builder/src/lib/auth/auth.ts +54 -0
- crew-builder/src/lib/components/AgentNode.svelte +43 -0
- crew-builder/src/lib/components/BotCard.svelte +33 -0
- crew-builder/src/lib/components/ChatBubble.svelte +67 -0
- crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
- crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
- crew-builder/src/lib/components/JsonViewer.svelte +24 -0
- crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
- crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
- crew-builder/src/lib/components/Toast.svelte +67 -0
- crew-builder/src/lib/components/Toolbar.svelte +157 -0
- crew-builder/src/lib/components/index.ts +10 -0
- crew-builder/src/lib/config.ts +8 -0
- crew-builder/src/lib/stores/auth.svelte.ts +228 -0
- crew-builder/src/lib/stores/crewStore.ts +369 -0
- crew-builder/src/lib/stores/theme.svelte.js +145 -0
- crew-builder/src/lib/stores/toast.svelte.ts +69 -0
- crew-builder/src/lib/utils/conversation.ts +39 -0
- crew-builder/src/lib/utils/markdown.ts +122 -0
- crew-builder/src/lib/utils/talkHistory.ts +47 -0
- crew-builder/src/routes/+layout.svelte +20 -0
- crew-builder/src/routes/+page.svelte +539 -0
- crew-builder/src/routes/agents/+page.svelte +247 -0
- crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
- crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
- crew-builder/src/routes/builder/+page.svelte +204 -0
- crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
- crew-builder/src/routes/crew/ask/+page.ts +1 -0
- crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
- crew-builder/src/routes/login/+page.svelte +197 -0
- crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
- crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
- crew-builder/static/README.md +1 -0
- crew-builder/svelte.config.js +11 -0
- crew-builder/tailwind.config.ts +53 -0
- crew-builder/tsconfig.json +3 -0
- crew-builder/vite.config.ts +10 -0
- mcp_servers/calculator_server.py +309 -0
- parrot/__init__.py +27 -0
- parrot/__pycache__/__init__.cpython-310.pyc +0 -0
- parrot/__pycache__/version.cpython-310.pyc +0 -0
- parrot/_version.py +34 -0
- parrot/a2a/__init__.py +48 -0
- parrot/a2a/client.py +658 -0
- parrot/a2a/discovery.py +89 -0
- parrot/a2a/mixin.py +257 -0
- parrot/a2a/models.py +376 -0
- parrot/a2a/server.py +770 -0
- parrot/agents/__init__.py +29 -0
- parrot/bots/__init__.py +12 -0
- parrot/bots/a2a_agent.py +19 -0
- parrot/bots/abstract.py +3139 -0
- parrot/bots/agent.py +1129 -0
- parrot/bots/basic.py +9 -0
- parrot/bots/chatbot.py +669 -0
- parrot/bots/data.py +1618 -0
- parrot/bots/database/__init__.py +5 -0
- parrot/bots/database/abstract.py +3071 -0
- parrot/bots/database/cache.py +286 -0
- parrot/bots/database/models.py +468 -0
- parrot/bots/database/prompts.py +154 -0
- parrot/bots/database/retries.py +98 -0
- parrot/bots/database/router.py +269 -0
- parrot/bots/database/sql.py +41 -0
- parrot/bots/db/__init__.py +6 -0
- parrot/bots/db/abstract.py +556 -0
- parrot/bots/db/bigquery.py +602 -0
- parrot/bots/db/cache.py +85 -0
- parrot/bots/db/documentdb.py +668 -0
- parrot/bots/db/elastic.py +1014 -0
- parrot/bots/db/influx.py +898 -0
- parrot/bots/db/mock.py +96 -0
- parrot/bots/db/multi.py +783 -0
- parrot/bots/db/prompts.py +185 -0
- parrot/bots/db/sql.py +1255 -0
- parrot/bots/db/tools.py +212 -0
- parrot/bots/document.py +680 -0
- parrot/bots/hrbot.py +15 -0
- parrot/bots/kb.py +170 -0
- parrot/bots/mcp.py +36 -0
- parrot/bots/orchestration/README.md +463 -0
- parrot/bots/orchestration/__init__.py +1 -0
- parrot/bots/orchestration/agent.py +155 -0
- parrot/bots/orchestration/crew.py +3330 -0
- parrot/bots/orchestration/fsm.py +1179 -0
- parrot/bots/orchestration/hr.py +434 -0
- parrot/bots/orchestration/storage/__init__.py +4 -0
- parrot/bots/orchestration/storage/memory.py +100 -0
- parrot/bots/orchestration/storage/mixin.py +119 -0
- parrot/bots/orchestration/verify.py +202 -0
- parrot/bots/product.py +204 -0
- parrot/bots/prompts/__init__.py +96 -0
- parrot/bots/prompts/agents.py +155 -0
- parrot/bots/prompts/data.py +216 -0
- parrot/bots/prompts/output_generation.py +8 -0
- parrot/bots/scraper/__init__.py +3 -0
- parrot/bots/scraper/models.py +122 -0
- parrot/bots/scraper/scraper.py +1173 -0
- parrot/bots/scraper/templates.py +115 -0
- parrot/bots/stores/__init__.py +5 -0
- parrot/bots/stores/local.py +172 -0
- parrot/bots/webdev.py +81 -0
- parrot/cli.py +17 -0
- parrot/clients/__init__.py +16 -0
- parrot/clients/base.py +1491 -0
- parrot/clients/claude.py +1191 -0
- parrot/clients/factory.py +129 -0
- parrot/clients/google.py +4567 -0
- parrot/clients/gpt.py +1975 -0
- parrot/clients/grok.py +432 -0
- parrot/clients/groq.py +986 -0
- parrot/clients/hf.py +582 -0
- parrot/clients/models.py +18 -0
- parrot/conf.py +395 -0
- parrot/embeddings/__init__.py +9 -0
- parrot/embeddings/base.py +157 -0
- parrot/embeddings/google.py +98 -0
- parrot/embeddings/huggingface.py +74 -0
- parrot/embeddings/openai.py +84 -0
- parrot/embeddings/processor.py +88 -0
- parrot/exceptions.c +13868 -0
- parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/exceptions.pxd +22 -0
- parrot/exceptions.pxi +15 -0
- parrot/exceptions.pyx +44 -0
- parrot/generators/__init__.py +29 -0
- parrot/generators/base.py +200 -0
- parrot/generators/html.py +293 -0
- parrot/generators/react.py +205 -0
- parrot/generators/streamlit.py +203 -0
- parrot/generators/template.py +105 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/agent.py +861 -0
- parrot/handlers/agents/__init__.py +1 -0
- parrot/handlers/agents/abstract.py +900 -0
- parrot/handlers/bots.py +338 -0
- parrot/handlers/chat.py +915 -0
- parrot/handlers/creation.sql +192 -0
- parrot/handlers/crew/ARCHITECTURE.md +362 -0
- parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
- parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
- parrot/handlers/crew/__init__.py +0 -0
- parrot/handlers/crew/handler.py +801 -0
- parrot/handlers/crew/models.py +229 -0
- parrot/handlers/crew/redis_persistence.py +523 -0
- parrot/handlers/jobs/__init__.py +10 -0
- parrot/handlers/jobs/job.py +384 -0
- parrot/handlers/jobs/mixin.py +627 -0
- parrot/handlers/jobs/models.py +115 -0
- parrot/handlers/jobs/worker.py +31 -0
- parrot/handlers/models.py +596 -0
- parrot/handlers/o365_auth.py +105 -0
- parrot/handlers/stream.py +337 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/aws.py +143 -0
- parrot/interfaces/credentials.py +113 -0
- parrot/interfaces/database.py +27 -0
- parrot/interfaces/google.py +1123 -0
- parrot/interfaces/hierarchy.py +1227 -0
- parrot/interfaces/http.py +651 -0
- parrot/interfaces/images/__init__.py +0 -0
- parrot/interfaces/images/plugins/__init__.py +24 -0
- parrot/interfaces/images/plugins/abstract.py +58 -0
- parrot/interfaces/images/plugins/analisys.py +148 -0
- parrot/interfaces/images/plugins/classify.py +150 -0
- parrot/interfaces/images/plugins/classifybase.py +182 -0
- parrot/interfaces/images/plugins/detect.py +150 -0
- parrot/interfaces/images/plugins/exif.py +1103 -0
- parrot/interfaces/images/plugins/hash.py +52 -0
- parrot/interfaces/images/plugins/vision.py +104 -0
- parrot/interfaces/images/plugins/yolo.py +66 -0
- parrot/interfaces/images/plugins/zerodetect.py +197 -0
- parrot/interfaces/o365.py +978 -0
- parrot/interfaces/onedrive.py +822 -0
- parrot/interfaces/sharepoint.py +1435 -0
- parrot/interfaces/soap.py +257 -0
- parrot/loaders/__init__.py +8 -0
- parrot/loaders/abstract.py +1131 -0
- parrot/loaders/audio.py +199 -0
- parrot/loaders/basepdf.py +53 -0
- parrot/loaders/basevideo.py +1568 -0
- parrot/loaders/csv.py +409 -0
- parrot/loaders/docx.py +116 -0
- parrot/loaders/epubloader.py +316 -0
- parrot/loaders/excel.py +199 -0
- parrot/loaders/factory.py +55 -0
- parrot/loaders/files/__init__.py +0 -0
- parrot/loaders/files/abstract.py +39 -0
- parrot/loaders/files/html.py +26 -0
- parrot/loaders/files/text.py +63 -0
- parrot/loaders/html.py +152 -0
- parrot/loaders/markdown.py +442 -0
- parrot/loaders/pdf.py +373 -0
- parrot/loaders/pdfmark.py +320 -0
- parrot/loaders/pdftables.py +506 -0
- parrot/loaders/ppt.py +476 -0
- parrot/loaders/qa.py +63 -0
- parrot/loaders/splitters/__init__.py +10 -0
- parrot/loaders/splitters/base.py +138 -0
- parrot/loaders/splitters/md.py +228 -0
- parrot/loaders/splitters/token.py +143 -0
- parrot/loaders/txt.py +26 -0
- parrot/loaders/video.py +89 -0
- parrot/loaders/videolocal.py +218 -0
- parrot/loaders/videounderstanding.py +377 -0
- parrot/loaders/vimeo.py +167 -0
- parrot/loaders/web.py +599 -0
- parrot/loaders/youtube.py +504 -0
- parrot/manager/__init__.py +5 -0
- parrot/manager/manager.py +1030 -0
- parrot/mcp/__init__.py +28 -0
- parrot/mcp/adapter.py +105 -0
- parrot/mcp/cli.py +174 -0
- parrot/mcp/client.py +119 -0
- parrot/mcp/config.py +75 -0
- parrot/mcp/integration.py +842 -0
- parrot/mcp/oauth.py +933 -0
- parrot/mcp/server.py +225 -0
- parrot/mcp/transports/__init__.py +3 -0
- parrot/mcp/transports/base.py +279 -0
- parrot/mcp/transports/grpc_session.py +163 -0
- parrot/mcp/transports/http.py +312 -0
- parrot/mcp/transports/mcp.proto +108 -0
- parrot/mcp/transports/quic.py +1082 -0
- parrot/mcp/transports/sse.py +330 -0
- parrot/mcp/transports/stdio.py +309 -0
- parrot/mcp/transports/unix.py +395 -0
- parrot/mcp/transports/websocket.py +547 -0
- parrot/memory/__init__.py +16 -0
- parrot/memory/abstract.py +209 -0
- parrot/memory/agent.py +32 -0
- parrot/memory/cache.py +175 -0
- parrot/memory/core.py +555 -0
- parrot/memory/file.py +153 -0
- parrot/memory/mem.py +131 -0
- parrot/memory/redis.py +613 -0
- parrot/models/__init__.py +46 -0
- parrot/models/basic.py +118 -0
- parrot/models/compliance.py +208 -0
- parrot/models/crew.py +395 -0
- parrot/models/detections.py +654 -0
- parrot/models/generation.py +85 -0
- parrot/models/google.py +223 -0
- parrot/models/groq.py +23 -0
- parrot/models/openai.py +30 -0
- parrot/models/outputs.py +285 -0
- parrot/models/responses.py +938 -0
- parrot/notifications/__init__.py +743 -0
- parrot/openapi/__init__.py +3 -0
- parrot/openapi/components.yaml +641 -0
- parrot/openapi/config.py +322 -0
- parrot/outputs/__init__.py +32 -0
- parrot/outputs/formats/__init__.py +108 -0
- parrot/outputs/formats/altair.py +359 -0
- parrot/outputs/formats/application.py +122 -0
- parrot/outputs/formats/base.py +351 -0
- parrot/outputs/formats/bokeh.py +356 -0
- parrot/outputs/formats/card.py +424 -0
- parrot/outputs/formats/chart.py +436 -0
- parrot/outputs/formats/d3.py +255 -0
- parrot/outputs/formats/echarts.py +310 -0
- parrot/outputs/formats/generators/__init__.py +0 -0
- parrot/outputs/formats/generators/abstract.py +61 -0
- parrot/outputs/formats/generators/panel.py +145 -0
- parrot/outputs/formats/generators/streamlit.py +86 -0
- parrot/outputs/formats/generators/terminal.py +63 -0
- parrot/outputs/formats/holoviews.py +310 -0
- parrot/outputs/formats/html.py +147 -0
- parrot/outputs/formats/jinja2.py +46 -0
- parrot/outputs/formats/json.py +87 -0
- parrot/outputs/formats/map.py +933 -0
- parrot/outputs/formats/markdown.py +172 -0
- parrot/outputs/formats/matplotlib.py +237 -0
- parrot/outputs/formats/mixins/__init__.py +0 -0
- parrot/outputs/formats/mixins/emaps.py +855 -0
- parrot/outputs/formats/plotly.py +341 -0
- parrot/outputs/formats/seaborn.py +310 -0
- parrot/outputs/formats/table.py +397 -0
- parrot/outputs/formats/template_report.py +138 -0
- parrot/outputs/formats/yaml.py +125 -0
- parrot/outputs/formatter.py +152 -0
- parrot/outputs/templates/__init__.py +95 -0
- parrot/pipelines/__init__.py +0 -0
- parrot/pipelines/abstract.py +210 -0
- parrot/pipelines/detector.py +124 -0
- parrot/pipelines/models.py +90 -0
- parrot/pipelines/planogram.py +3002 -0
- parrot/pipelines/table.sql +97 -0
- parrot/plugins/__init__.py +106 -0
- parrot/plugins/importer.py +80 -0
- parrot/py.typed +0 -0
- parrot/registry/__init__.py +18 -0
- parrot/registry/registry.py +594 -0
- parrot/scheduler/__init__.py +1189 -0
- parrot/scheduler/models.py +60 -0
- parrot/security/__init__.py +16 -0
- parrot/security/prompt_injection.py +268 -0
- parrot/security/security_events.sql +25 -0
- parrot/services/__init__.py +1 -0
- parrot/services/mcp/__init__.py +8 -0
- parrot/services/mcp/config.py +13 -0
- parrot/services/mcp/server.py +295 -0
- parrot/services/o365_remote_auth.py +235 -0
- parrot/stores/__init__.py +7 -0
- parrot/stores/abstract.py +352 -0
- parrot/stores/arango.py +1090 -0
- parrot/stores/bigquery.py +1377 -0
- parrot/stores/cache.py +106 -0
- parrot/stores/empty.py +10 -0
- parrot/stores/faiss_store.py +1157 -0
- parrot/stores/kb/__init__.py +9 -0
- parrot/stores/kb/abstract.py +68 -0
- parrot/stores/kb/cache.py +165 -0
- parrot/stores/kb/doc.py +325 -0
- parrot/stores/kb/hierarchy.py +346 -0
- parrot/stores/kb/local.py +457 -0
- parrot/stores/kb/prompt.py +28 -0
- parrot/stores/kb/redis.py +659 -0
- parrot/stores/kb/store.py +115 -0
- parrot/stores/kb/user.py +374 -0
- parrot/stores/models.py +59 -0
- parrot/stores/pgvector.py +3 -0
- parrot/stores/postgres.py +2853 -0
- parrot/stores/utils/__init__.py +0 -0
- parrot/stores/utils/chunking.py +197 -0
- parrot/telemetry/__init__.py +3 -0
- parrot/telemetry/mixin.py +111 -0
- parrot/template/__init__.py +3 -0
- parrot/template/engine.py +259 -0
- parrot/tools/__init__.py +23 -0
- parrot/tools/abstract.py +644 -0
- parrot/tools/agent.py +363 -0
- parrot/tools/arangodbsearch.py +537 -0
- parrot/tools/arxiv_tool.py +188 -0
- parrot/tools/calculator/__init__.py +3 -0
- parrot/tools/calculator/operations/__init__.py +38 -0
- parrot/tools/calculator/operations/calculus.py +80 -0
- parrot/tools/calculator/operations/statistics.py +76 -0
- parrot/tools/calculator/tool.py +150 -0
- parrot/tools/cloudwatch.py +988 -0
- parrot/tools/codeinterpreter/__init__.py +127 -0
- parrot/tools/codeinterpreter/executor.py +371 -0
- parrot/tools/codeinterpreter/internals.py +473 -0
- parrot/tools/codeinterpreter/models.py +643 -0
- parrot/tools/codeinterpreter/prompts.py +224 -0
- parrot/tools/codeinterpreter/tool.py +664 -0
- parrot/tools/company_info/__init__.py +6 -0
- parrot/tools/company_info/tool.py +1138 -0
- parrot/tools/correlationanalysis.py +437 -0
- parrot/tools/database/abstract.py +286 -0
- parrot/tools/database/bq.py +115 -0
- parrot/tools/database/cache.py +284 -0
- parrot/tools/database/models.py +95 -0
- parrot/tools/database/pg.py +343 -0
- parrot/tools/databasequery.py +1159 -0
- parrot/tools/db.py +1800 -0
- parrot/tools/ddgo.py +370 -0
- parrot/tools/decorators.py +271 -0
- parrot/tools/dftohtml.py +282 -0
- parrot/tools/document.py +549 -0
- parrot/tools/ecs.py +819 -0
- parrot/tools/edareport.py +368 -0
- parrot/tools/elasticsearch.py +1049 -0
- parrot/tools/employees.py +462 -0
- parrot/tools/epson/__init__.py +96 -0
- parrot/tools/excel.py +683 -0
- parrot/tools/file/__init__.py +13 -0
- parrot/tools/file/abstract.py +76 -0
- parrot/tools/file/gcs.py +378 -0
- parrot/tools/file/local.py +284 -0
- parrot/tools/file/s3.py +511 -0
- parrot/tools/file/tmp.py +309 -0
- parrot/tools/file/tool.py +501 -0
- parrot/tools/file_reader.py +129 -0
- parrot/tools/flowtask/__init__.py +19 -0
- parrot/tools/flowtask/tool.py +761 -0
- parrot/tools/gittoolkit.py +508 -0
- parrot/tools/google/__init__.py +18 -0
- parrot/tools/google/base.py +169 -0
- parrot/tools/google/tools.py +1251 -0
- parrot/tools/googlelocation.py +5 -0
- parrot/tools/googleroutes.py +5 -0
- parrot/tools/googlesearch.py +5 -0
- parrot/tools/googlesitesearch.py +5 -0
- parrot/tools/googlevoice.py +2 -0
- parrot/tools/gvoice.py +695 -0
- parrot/tools/ibisworld/README.md +225 -0
- parrot/tools/ibisworld/__init__.py +11 -0
- parrot/tools/ibisworld/tool.py +366 -0
- parrot/tools/jiratoolkit.py +1718 -0
- parrot/tools/manager.py +1098 -0
- parrot/tools/math.py +152 -0
- parrot/tools/metadata.py +476 -0
- parrot/tools/msteams.py +1621 -0
- parrot/tools/msword.py +635 -0
- parrot/tools/multidb.py +580 -0
- parrot/tools/multistoresearch.py +369 -0
- parrot/tools/networkninja.py +167 -0
- parrot/tools/nextstop/__init__.py +4 -0
- parrot/tools/nextstop/base.py +286 -0
- parrot/tools/nextstop/employee.py +733 -0
- parrot/tools/nextstop/store.py +462 -0
- parrot/tools/notification.py +435 -0
- parrot/tools/o365/__init__.py +42 -0
- parrot/tools/o365/base.py +295 -0
- parrot/tools/o365/bundle.py +522 -0
- parrot/tools/o365/events.py +554 -0
- parrot/tools/o365/mail.py +992 -0
- parrot/tools/o365/onedrive.py +497 -0
- parrot/tools/o365/sharepoint.py +641 -0
- parrot/tools/openapi_toolkit.py +904 -0
- parrot/tools/openweather.py +527 -0
- parrot/tools/pdfprint.py +1001 -0
- parrot/tools/powerbi.py +518 -0
- parrot/tools/powerpoint.py +1113 -0
- parrot/tools/pricestool.py +146 -0
- parrot/tools/products/__init__.py +246 -0
- parrot/tools/prophet_tool.py +171 -0
- parrot/tools/pythonpandas.py +630 -0
- parrot/tools/pythonrepl.py +910 -0
- parrot/tools/qsource.py +436 -0
- parrot/tools/querytoolkit.py +395 -0
- parrot/tools/quickeda.py +827 -0
- parrot/tools/resttool.py +553 -0
- parrot/tools/retail/__init__.py +0 -0
- parrot/tools/retail/bby.py +528 -0
- parrot/tools/sandboxtool.py +703 -0
- parrot/tools/sassie/__init__.py +352 -0
- parrot/tools/scraping/__init__.py +7 -0
- parrot/tools/scraping/docs/select.md +466 -0
- parrot/tools/scraping/documentation.md +1278 -0
- parrot/tools/scraping/driver.py +436 -0
- parrot/tools/scraping/models.py +576 -0
- parrot/tools/scraping/options.py +85 -0
- parrot/tools/scraping/orchestrator.py +517 -0
- parrot/tools/scraping/readme.md +740 -0
- parrot/tools/scraping/tool.py +3115 -0
- parrot/tools/seasonaldetection.py +642 -0
- parrot/tools/shell_tool/__init__.py +5 -0
- parrot/tools/shell_tool/actions.py +408 -0
- parrot/tools/shell_tool/engine.py +155 -0
- parrot/tools/shell_tool/models.py +322 -0
- parrot/tools/shell_tool/tool.py +442 -0
- parrot/tools/site_search.py +214 -0
- parrot/tools/textfile.py +418 -0
- parrot/tools/think.py +378 -0
- parrot/tools/toolkit.py +298 -0
- parrot/tools/webapp_tool.py +187 -0
- parrot/tools/whatif.py +1279 -0
- parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
- parrot/tools/workday/__init__.py +6 -0
- parrot/tools/workday/models.py +1389 -0
- parrot/tools/workday/tool.py +1293 -0
- parrot/tools/yfinance_tool.py +306 -0
- parrot/tools/zipcode.py +217 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/helpers.py +73 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.c +12078 -0
- parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/parsers/toml.pyx +21 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpp +20936 -0
- parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/types.pyx +213 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- parrot/yaml-rs/Cargo.lock +350 -0
- parrot/yaml-rs/Cargo.toml +19 -0
- parrot/yaml-rs/pyproject.toml +19 -0
- parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
- parrot/yaml-rs/src/lib.rs +222 -0
- requirements/docker-compose.yml +24 -0
- requirements/requirements-dev.txt +21 -0
|
@@ -0,0 +1,3071 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Schema-Centric AbstractDbAgent for Multi-Tenant Architecture
|
|
3
|
+
===========================================================
|
|
4
|
+
|
|
5
|
+
Designed for:
|
|
6
|
+
- 96+ schemas with ~50 tables each (~4,800+ total tables)
|
|
7
|
+
- Per-client schema isolation
|
|
8
|
+
- LRU + Vector store caching (no Redis)
|
|
9
|
+
- Dual execution paths: natural language generation + direct SQL tools
|
|
10
|
+
- "Show me" = data retrieval pattern recognition
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from abc import ABC
|
|
14
|
+
import inspect
|
|
15
|
+
from typing import Dict, Any, List, Optional, Union, Tuple, Type, get_origin, get_args
|
|
16
|
+
from dataclasses import is_dataclass
|
|
17
|
+
from datetime import datetime
|
|
18
|
+
from string import Template
|
|
19
|
+
import re
|
|
20
|
+
import uuid
|
|
21
|
+
from pydantic import BaseModel
|
|
22
|
+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
|
23
|
+
from sqlalchemy.orm import sessionmaker
|
|
24
|
+
from sqlalchemy import text
|
|
25
|
+
import pandas as pd
|
|
26
|
+
from ...tools.manager import ToolManager
|
|
27
|
+
from ...stores.abstract import AbstractStore
|
|
28
|
+
from ..abstract import AbstractBot
|
|
29
|
+
from ...models import AIMessage, CompletionUsage
|
|
30
|
+
from .cache import SchemaMetadataCache
|
|
31
|
+
from .router import SchemaQueryRouter
|
|
32
|
+
from .models import (
|
|
33
|
+
UserRole,
|
|
34
|
+
QueryIntent,
|
|
35
|
+
RouteDecision,
|
|
36
|
+
TableMetadata,
|
|
37
|
+
QueryExecutionResponse,
|
|
38
|
+
OutputComponent,
|
|
39
|
+
DatabaseResponse,
|
|
40
|
+
get_default_components,
|
|
41
|
+
components_from_string
|
|
42
|
+
)
|
|
43
|
+
from .prompts import DB_AGENT_PROMPT
|
|
44
|
+
from .retries import QueryRetryConfig, SQLRetryHandler
|
|
45
|
+
from parrot.tools.database.pg import PgSchemaSearchTool
|
|
46
|
+
from parrot.tools.database.bq import BQSchemaSearchTool
|
|
47
|
+
from ...memory import ConversationTurn
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# ============================================================================
|
|
51
|
+
# SCHEMA-CENTRIC ABSTRACT DB AGENT
|
|
52
|
+
# ============================================================================
|
|
53
|
+
|
|
54
|
+
class AbstractDBAgent(AbstractBot, ABC):
|
|
55
|
+
"""Schema-centric AbstractDBAgent for multi-tenant architecture."""
|
|
56
|
+
_default_temperature: float = 0.0
|
|
57
|
+
max_tokens: int = 8192
|
|
58
|
+
|
|
59
|
+
def __init__(
|
|
60
|
+
self,
|
|
61
|
+
name: str = "DBAgent",
|
|
62
|
+
dsn: str = None,
|
|
63
|
+
allowed_schemas: Union[str, List[str]] = "public",
|
|
64
|
+
primary_schema: Optional[str] = None,
|
|
65
|
+
vector_store: Optional[AbstractStore] = None,
|
|
66
|
+
auto_analyze_schema: bool = True,
|
|
67
|
+
client_id: Optional[str] = None,
|
|
68
|
+
database_type: str = "postgresql",
|
|
69
|
+
system_prompt_template: Optional[str] = None,
|
|
70
|
+
**kwargs
|
|
71
|
+
):
|
|
72
|
+
super().__init__(name=name, **kwargs)
|
|
73
|
+
self.enable_tools = True # Enable tools by default
|
|
74
|
+
self.role = kwargs.get(
|
|
75
|
+
'role', 'Database Analysis Assistant'
|
|
76
|
+
)
|
|
77
|
+
self.goal = kwargs.get(
|
|
78
|
+
'goal', 'Help users interact with databases using natural language'
|
|
79
|
+
)
|
|
80
|
+
self.backstory = kwargs.get(
|
|
81
|
+
'backstory',
|
|
82
|
+
"""
|
|
83
|
+
- Help users query, analyze, and understand database information
|
|
84
|
+
- Generate accurate SQL queries based on available schema metadata
|
|
85
|
+
- Provide data insights and recommendations
|
|
86
|
+
- Maintain conversation context for better user experience.
|
|
87
|
+
"""
|
|
88
|
+
)
|
|
89
|
+
# System Prompt Template:
|
|
90
|
+
self.system_prompt_template = system_prompt_template or DB_AGENT_PROMPT
|
|
91
|
+
|
|
92
|
+
# Multi-schema configuration
|
|
93
|
+
if isinstance(allowed_schemas, str):
|
|
94
|
+
self.allowed_schemas = [allowed_schemas]
|
|
95
|
+
else:
|
|
96
|
+
self.allowed_schemas = allowed_schemas
|
|
97
|
+
|
|
98
|
+
# Primary schema is the main focus, defaults to first allowed schema
|
|
99
|
+
self.primary_schema = primary_schema or self.allowed_schemas[0]
|
|
100
|
+
|
|
101
|
+
# Ensure primary schema is in allowed list
|
|
102
|
+
if self.primary_schema not in self.allowed_schemas:
|
|
103
|
+
self.allowed_schemas.insert(0, self.primary_schema)
|
|
104
|
+
|
|
105
|
+
self.client_id = client_id or self.primary_schema
|
|
106
|
+
self.dsn = dsn
|
|
107
|
+
self.database_type = database_type
|
|
108
|
+
|
|
109
|
+
# Database components
|
|
110
|
+
self.engine: Optional[AsyncEngine] = None
|
|
111
|
+
self.session_maker: Optional[sessionmaker] = None
|
|
112
|
+
|
|
113
|
+
# Per-agent ToolManager
|
|
114
|
+
self.tool_manager = ToolManager(
|
|
115
|
+
logger=self.logger,
|
|
116
|
+
debug=getattr(self, '_debug', False)
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
# Schema-aware components
|
|
120
|
+
self.metadata_cache = SchemaMetadataCache(
|
|
121
|
+
vector_store=vector_store, # Optional - can be None
|
|
122
|
+
lru_maxsize=500, # Large cache for many tables
|
|
123
|
+
lru_ttl=1800 # 30 minutes
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
# Vector Store:
|
|
127
|
+
self.knowledge_store = vector_store
|
|
128
|
+
|
|
129
|
+
self.query_router = SchemaQueryRouter(
|
|
130
|
+
primary_schema=self.primary_schema,
|
|
131
|
+
allowed_schemas=self.allowed_schemas
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
# Schema analysis flag
|
|
135
|
+
self.schema_analyzed = False
|
|
136
|
+
self.auto_analyze_schema = auto_analyze_schema
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
async def configure(self, app=None) -> None:
|
|
140
|
+
"""Configure agent with proper tool sharing."""
|
|
141
|
+
await super().configure(app)
|
|
142
|
+
|
|
143
|
+
# Connect to database
|
|
144
|
+
await self.connect_database()
|
|
145
|
+
|
|
146
|
+
# Register tools
|
|
147
|
+
self._register_database_tools()
|
|
148
|
+
|
|
149
|
+
# Share tools with LLM
|
|
150
|
+
await self._share_tools_with_llm()
|
|
151
|
+
|
|
152
|
+
# Auto-analyze schema if enabled
|
|
153
|
+
if self.auto_analyze_schema and not self.schema_analyzed:
|
|
154
|
+
await self.analyze_schema()
|
|
155
|
+
|
|
156
|
+
def _register_database_tools(self):
|
|
157
|
+
"""Register database-specific tools."""
|
|
158
|
+
if self.database_type == "bigquery":
|
|
159
|
+
tool_cls = BQSchemaSearchTool
|
|
160
|
+
else:
|
|
161
|
+
tool_cls = PgSchemaSearchTool
|
|
162
|
+
|
|
163
|
+
self.schema_tool = tool_cls(
|
|
164
|
+
engine=self.engine,
|
|
165
|
+
metadata_cache=self.metadata_cache,
|
|
166
|
+
allowed_schemas=self.allowed_schemas.copy(),
|
|
167
|
+
session_maker=self.session_maker
|
|
168
|
+
)
|
|
169
|
+
self.tool_manager.add_tool(self.schema_tool)
|
|
170
|
+
self.logger.debug(
|
|
171
|
+
f"Registered SchemaSearchTool with {len(self.allowed_schemas)} schemas"
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
async def _share_tools_with_llm(self):
|
|
175
|
+
"""Share ToolManager tools with LLM Client."""
|
|
176
|
+
if not hasattr(self, '_llm') or not self._llm:
|
|
177
|
+
self.logger.warning("LLM client not initialized, cannot share tools")
|
|
178
|
+
return
|
|
179
|
+
|
|
180
|
+
if not hasattr(self._llm, 'tool_manager'):
|
|
181
|
+
self.logger.warning("LLM client has no tool_manager")
|
|
182
|
+
return
|
|
183
|
+
|
|
184
|
+
tools = list(self.tool_manager.get_tools())
|
|
185
|
+
for tool in tools:
|
|
186
|
+
self._llm.tool_manager.add_tool(tool)
|
|
187
|
+
|
|
188
|
+
self.logger.info(
|
|
189
|
+
f"Shared {len(tools)} tools with LLM Client"
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
def _ensure_async_driver(self, dsn: str) -> str:
|
|
193
|
+
return dsn
|
|
194
|
+
|
|
195
|
+
async def connect_database(self) -> None:
|
|
196
|
+
"""Connect to SQL database using SQLAlchemy async."""
|
|
197
|
+
if not self.dsn:
|
|
198
|
+
raise ValueError("Connection string is required")
|
|
199
|
+
|
|
200
|
+
try:
|
|
201
|
+
# Ensure async driver
|
|
202
|
+
connection_string = self._ensure_async_driver(self.dsn)
|
|
203
|
+
# Build search path from allowed schemas
|
|
204
|
+
search_path = ','.join(self.allowed_schemas)
|
|
205
|
+
|
|
206
|
+
self.engine = create_async_engine(
|
|
207
|
+
connection_string,
|
|
208
|
+
echo=False,
|
|
209
|
+
pool_pre_ping=True,
|
|
210
|
+
pool_recycle=3600,
|
|
211
|
+
# Multi-schema search path
|
|
212
|
+
connect_args={
|
|
213
|
+
"server_settings": {
|
|
214
|
+
"search_path": search_path
|
|
215
|
+
}
|
|
216
|
+
}
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
self.session_maker = sessionmaker(
|
|
220
|
+
self.engine,
|
|
221
|
+
class_=AsyncSession,
|
|
222
|
+
expire_on_commit=False
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
# Test connection
|
|
226
|
+
async with self.engine.begin() as conn:
|
|
227
|
+
result = await conn.execute(text("SELECT current_schema()"))
|
|
228
|
+
current_schema = result.scalar()
|
|
229
|
+
self.logger.info(
|
|
230
|
+
f"Connected to database. Current schema: {current_schema}, "
|
|
231
|
+
f"Search path: {search_path}"
|
|
232
|
+
)
|
|
233
|
+
|
|
234
|
+
except Exception as e:
|
|
235
|
+
self.logger.error(f"Failed to connect to database: {e}")
|
|
236
|
+
raise
|
|
237
|
+
|
|
238
|
+
async def analyze_schema(self) -> None:
|
|
239
|
+
"""Analyze all allowed schemas and populate metadata cache."""
|
|
240
|
+
try:
|
|
241
|
+
self.logger.notice(
|
|
242
|
+
f"Analyzing schemas: {self.allowed_schemas} (primary: {self.primary_schema})"
|
|
243
|
+
)
|
|
244
|
+
|
|
245
|
+
# Delegate to schema manager tool
|
|
246
|
+
analysis_results = await self.schema_tool.analyze_all_schemas()
|
|
247
|
+
|
|
248
|
+
# Log results
|
|
249
|
+
total_tables = sum(analysis_results.values())
|
|
250
|
+
for schema_name, table_count in analysis_results.items():
|
|
251
|
+
if table_count > 0:
|
|
252
|
+
self.logger.info(f"Schema '{schema_name}': {table_count} tables/views")
|
|
253
|
+
else:
|
|
254
|
+
self.logger.warning(f"Schema '{schema_name}': Analysis failed or no tables found")
|
|
255
|
+
|
|
256
|
+
self.schema_analyzed = True
|
|
257
|
+
self.logger.info(f"Schema analysis completed. Total: {total_tables} tables/views")
|
|
258
|
+
|
|
259
|
+
except Exception as e:
|
|
260
|
+
self.logger.error(f"Schema analysis failed: {e}")
|
|
261
|
+
raise
|
|
262
|
+
|
|
263
|
+
async def get_table_metadata(self, schema: str, tablename: str) -> Optional[TableMetadata]:
|
|
264
|
+
"""Get table metadata - delegates to schema tool."""
|
|
265
|
+
if not self.schema_tool:
|
|
266
|
+
raise RuntimeError("Schema tool not initialized. Call configure() first.")
|
|
267
|
+
|
|
268
|
+
return await self.schema_tool.get_table_details(schema, tablename)
|
|
269
|
+
|
|
270
|
+
async def get_schema_overview(self, schema_name: str) -> Optional[Dict[str, Any]]:
|
|
271
|
+
"""Get schema overview - delegates to schema Tool."""
|
|
272
|
+
if not self.schema_tool:
|
|
273
|
+
raise RuntimeError("Schema Tool not initialized. Call configure() first.")
|
|
274
|
+
|
|
275
|
+
return await self.schema_tool.get_schema_overview(schema_name)
|
|
276
|
+
|
|
277
|
+
async def create_system_prompt(
|
|
278
|
+
self,
|
|
279
|
+
user_context: str = "",
|
|
280
|
+
context: str = "",
|
|
281
|
+
vector_context: str = "",
|
|
282
|
+
conversation_context: str = "",
|
|
283
|
+
metadata_context: str = "",
|
|
284
|
+
vector_metadata: Optional[Dict[str, Any]] = None,
|
|
285
|
+
route: Optional[RouteDecision] = None,
|
|
286
|
+
**kwargs
|
|
287
|
+
) -> str:
|
|
288
|
+
"""
|
|
289
|
+
Create the complete system prompt using template substitution.
|
|
290
|
+
|
|
291
|
+
Args:
|
|
292
|
+
user_context: User-specific context for database interaction
|
|
293
|
+
context: Additional context for the request
|
|
294
|
+
vector_context: Context from vector store similarity search
|
|
295
|
+
conversation_context: Previous conversation context
|
|
296
|
+
metadata_context: Schema metadata context
|
|
297
|
+
vector_metadata: Metadata from vector search
|
|
298
|
+
route: Query route decision for specialized instructions
|
|
299
|
+
**kwargs: Additional template variables
|
|
300
|
+
|
|
301
|
+
Returns:
|
|
302
|
+
Complete system prompt string
|
|
303
|
+
"""
|
|
304
|
+
# Build context sections
|
|
305
|
+
context_parts = []
|
|
306
|
+
|
|
307
|
+
# User context section
|
|
308
|
+
if user_context:
|
|
309
|
+
user_section = f"""
|
|
310
|
+
**User Context:**
|
|
311
|
+
{user_context}
|
|
312
|
+
|
|
313
|
+
*Instructions: Tailor your response to the user's role, expertise level, and objectives described above.*
|
|
314
|
+
"""
|
|
315
|
+
context_parts.append(user_section)
|
|
316
|
+
|
|
317
|
+
# Additional context
|
|
318
|
+
if context:
|
|
319
|
+
context_parts.append(f"**Additional Context:**\n{context}")
|
|
320
|
+
|
|
321
|
+
# Database context from schema metadata
|
|
322
|
+
database_context_parts = []
|
|
323
|
+
if metadata_context:
|
|
324
|
+
database_context_parts.append(
|
|
325
|
+
f"**Available Schema Information:**\n{metadata_context}"
|
|
326
|
+
)
|
|
327
|
+
|
|
328
|
+
# Add current database info
|
|
329
|
+
db_info = f"""**Database Configuration:**
|
|
330
|
+
- Primary Schema: {self.primary_schema}
|
|
331
|
+
- Allowed Schemas: {', '.join(self.allowed_schemas)}
|
|
332
|
+
- Database Type: {self.database_type}
|
|
333
|
+
- Total Schemas: {len(self.allowed_schemas)}"""
|
|
334
|
+
database_context_parts.append(db_info)
|
|
335
|
+
|
|
336
|
+
# Vector context from knowledge store
|
|
337
|
+
vector_section = ""
|
|
338
|
+
if vector_context:
|
|
339
|
+
vector_section = f"""**Relevant Knowledge Base Context:**
|
|
340
|
+
{vector_context}
|
|
341
|
+
"""
|
|
342
|
+
if vector_metadata and vector_metadata.get('tables_referenced'):
|
|
343
|
+
referenced_tables = [t for t in vector_metadata['tables_referenced'] if t]
|
|
344
|
+
if referenced_tables:
|
|
345
|
+
vector_section += f"\n*Referenced Tables: {', '.join(set(referenced_tables))}*"
|
|
346
|
+
|
|
347
|
+
# Conversation history section
|
|
348
|
+
chat_section = ""
|
|
349
|
+
if conversation_context:
|
|
350
|
+
chat_section = f"""**Previous Conversation:**
|
|
351
|
+
{conversation_context}
|
|
352
|
+
|
|
353
|
+
*Note: Consider previous context when formulating your response.*
|
|
354
|
+
"""
|
|
355
|
+
|
|
356
|
+
# Route-specific instructions
|
|
357
|
+
route_instructions = ""
|
|
358
|
+
if route:
|
|
359
|
+
if route.intent == QueryIntent.SHOW_DATA:
|
|
360
|
+
route_instructions = "\n**Current Task**: Generate and execute SQL to retrieve and display data."
|
|
361
|
+
elif route.intent == QueryIntent.GENERATE_QUERY:
|
|
362
|
+
route_instructions = "\n**Current Task**: Generate SQL query based on user request and available schema."
|
|
363
|
+
elif route.intent == QueryIntent.ANALYZE_DATA:
|
|
364
|
+
route_instructions = "\n**Current Task**: Analyze data and provide insights with supporting queries."
|
|
365
|
+
elif route.intent == QueryIntent.EXPLORE_SCHEMA:
|
|
366
|
+
route_instructions = "\n**Current Task**: Help user explore and understand the database schema."
|
|
367
|
+
|
|
368
|
+
# Template substitution
|
|
369
|
+
template = Template(self.system_prompt_template)
|
|
370
|
+
|
|
371
|
+
try:
|
|
372
|
+
system_prompt = template.safe_substitute(
|
|
373
|
+
user_context=user_section if user_context else "",
|
|
374
|
+
database_context="\n\n".join(database_context_parts),
|
|
375
|
+
context="\n\n".join(context_parts) if context_parts else "",
|
|
376
|
+
vector_context=vector_section,
|
|
377
|
+
chat_history=chat_section,
|
|
378
|
+
route_instructions=route_instructions,
|
|
379
|
+
database_type=self.database_type,
|
|
380
|
+
**kwargs
|
|
381
|
+
)
|
|
382
|
+
|
|
383
|
+
return system_prompt
|
|
384
|
+
|
|
385
|
+
except Exception as e:
|
|
386
|
+
self.logger.error(f"Error in template substitution: {e}")
|
|
387
|
+
# Fallback to basic prompt
|
|
388
|
+
return f"""You are a database assistant for {self.database_type} databases.
|
|
389
|
+
Primary Schema: {self.primary_schema}
|
|
390
|
+
Available Schemas: {', '.join(self.allowed_schemas)}
|
|
391
|
+
|
|
392
|
+
{user_context if user_context else ''}
|
|
393
|
+
{context if context else ''}
|
|
394
|
+
|
|
395
|
+
Please help the user with their database query using available tools."""
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def _parse_components(
|
|
399
|
+
self,
|
|
400
|
+
user_role: UserRole,
|
|
401
|
+
output_components: Optional[Union[str, OutputComponent]],
|
|
402
|
+
add_components: Optional[Union[str, OutputComponent]],
|
|
403
|
+
remove_components: Optional[Union[str, OutputComponent]]
|
|
404
|
+
) -> OutputComponent:
|
|
405
|
+
"""Parse and combine output components from various inputs."""
|
|
406
|
+
|
|
407
|
+
if output_components is not None:
|
|
408
|
+
# Explicit override
|
|
409
|
+
if isinstance(output_components, str):
|
|
410
|
+
final_components = components_from_string(output_components)
|
|
411
|
+
else:
|
|
412
|
+
final_components = output_components
|
|
413
|
+
else:
|
|
414
|
+
# Start with role defaults
|
|
415
|
+
final_components = get_default_components(user_role)
|
|
416
|
+
|
|
417
|
+
# Apply additions
|
|
418
|
+
if add_components:
|
|
419
|
+
if isinstance(add_components, str):
|
|
420
|
+
add_comp = components_from_string(add_components)
|
|
421
|
+
else:
|
|
422
|
+
add_comp = add_components
|
|
423
|
+
final_components |= add_comp
|
|
424
|
+
|
|
425
|
+
# Apply removals
|
|
426
|
+
if remove_components:
|
|
427
|
+
if isinstance(remove_components, str):
|
|
428
|
+
remove_comp = components_from_string(remove_components)
|
|
429
|
+
else:
|
|
430
|
+
remove_comp = remove_components
|
|
431
|
+
final_components &= ~remove_comp
|
|
432
|
+
|
|
433
|
+
return final_components
|
|
434
|
+
|
|
435
|
+
def _is_structured_output_format(self, output_format) -> bool:
|
|
436
|
+
"""Check if output_format is a BaseModel or dataclass."""
|
|
437
|
+
if output_format is None or isinstance(output_format, str):
|
|
438
|
+
return False
|
|
439
|
+
# Check if it's a Pydantic BaseModel class
|
|
440
|
+
try:
|
|
441
|
+
if inspect.isclass(output_format) and issubclass(output_format, BaseModel):
|
|
442
|
+
return True
|
|
443
|
+
except (TypeError, ImportError):
|
|
444
|
+
pass
|
|
445
|
+
# Check if it's a dataclass
|
|
446
|
+
try:
|
|
447
|
+
if inspect.isclass(output_format) and is_dataclass(output_format):
|
|
448
|
+
return True
|
|
449
|
+
except (TypeError, ImportError):
|
|
450
|
+
pass
|
|
451
|
+
|
|
452
|
+
return False
|
|
453
|
+
|
|
454
|
+
async def ask(
|
|
455
|
+
self,
|
|
456
|
+
query: str,
|
|
457
|
+
context: Optional[str] = None,
|
|
458
|
+
user_role: UserRole = UserRole.DATA_ANALYST,
|
|
459
|
+
user_context: Optional[str] = None,
|
|
460
|
+
output_components: Optional[Union[str, OutputComponent]] = None,
|
|
461
|
+
output_format: Optional[Union[str, Type[BaseModel], Type]] = None, # "markdown", "json", "dataframe"
|
|
462
|
+
session_id: Optional[str] = None,
|
|
463
|
+
user_id: Optional[str] = None,
|
|
464
|
+
use_conversation_history: bool = True,
|
|
465
|
+
# Component customization
|
|
466
|
+
add_components: Optional[Union[str, OutputComponent]] = None,
|
|
467
|
+
remove_components: Optional[Union[str, OutputComponent]] = None,
|
|
468
|
+
enable_retry: bool = True,
|
|
469
|
+
**kwargs
|
|
470
|
+
) -> AIMessage:
|
|
471
|
+
"""
|
|
472
|
+
Ask method with role-based component responses and structured output support.
|
|
473
|
+
|
|
474
|
+
Args:
|
|
475
|
+
query: The user's question about the database
|
|
476
|
+
user_role: User role determining default response components
|
|
477
|
+
output_components: Override default components (string or OutputComponent flags)
|
|
478
|
+
output_format: Output format preference:
|
|
479
|
+
- String: "markdown", "json", "dataframe"
|
|
480
|
+
- BaseModel: Pydantic model class for structured output
|
|
481
|
+
- Dataclass: Dataclass for structured output
|
|
482
|
+
add_components: Additional components to include (string or OutputComponent flags)
|
|
483
|
+
remove_components: Components to exclude (string or OutputComponent flags)
|
|
484
|
+
context: Additional context for the request
|
|
485
|
+
user_context: User-specific context
|
|
486
|
+
enable_retry: Whether to enable query retry on errors
|
|
487
|
+
session_id: Session identifier for conversation history
|
|
488
|
+
user_id: User identifier
|
|
489
|
+
use_conversation_history: Whether to use conversation history
|
|
490
|
+
**kwargs: Additional arguments for LLM
|
|
491
|
+
|
|
492
|
+
Returns:
|
|
493
|
+
AIMessage: Enhanced response with role-appropriate components
|
|
494
|
+
|
|
495
|
+
Examples:
|
|
496
|
+
# Business user wants all inventory data
|
|
497
|
+
response = await agent.ask(
|
|
498
|
+
"Show me all inventory items",
|
|
499
|
+
user_role=UserRole.BUSINESS_USER
|
|
500
|
+
)
|
|
501
|
+
|
|
502
|
+
# Developer wants table metadata in markdown
|
|
503
|
+
response = await agent.ask(
|
|
504
|
+
"Return in markdown format the metadata of table inventory in schema hisense",
|
|
505
|
+
user_role=UserRole.DEVELOPER,
|
|
506
|
+
output_format="markdown"
|
|
507
|
+
)
|
|
508
|
+
|
|
509
|
+
# Data scientist wants DataFrame output
|
|
510
|
+
response = await agent.ask(
|
|
511
|
+
"Get sales data for analysis",
|
|
512
|
+
user_role=UserRole.DATA_SCIENTIST
|
|
513
|
+
)
|
|
514
|
+
|
|
515
|
+
# DBA wants performance analysis
|
|
516
|
+
response = await agent.ask(
|
|
517
|
+
"Analyze slow queries on user table",
|
|
518
|
+
user_role=UserRole.DATABASE_ADMIN
|
|
519
|
+
)
|
|
520
|
+
|
|
521
|
+
# Custom component combination
|
|
522
|
+
response = await agent.ask(
|
|
523
|
+
"Get user data",
|
|
524
|
+
user_role=UserRole.DATA_ANALYST,
|
|
525
|
+
add_components="performance,optimize"
|
|
526
|
+
)
|
|
527
|
+
|
|
528
|
+
# Structured output with dataclass
|
|
529
|
+
@dataclass
|
|
530
|
+
class QueryAnalysis:
|
|
531
|
+
sql_query: str
|
|
532
|
+
execution_plan: str
|
|
533
|
+
performance_metrics: Dict[str, Any]
|
|
534
|
+
optimization_tips: List[str]
|
|
535
|
+
|
|
536
|
+
response = await agent.ask("Analyze query performance",
|
|
537
|
+
user_role=UserRole.QUERY_DEVELOPER,
|
|
538
|
+
output_format=QueryAnalysis)
|
|
539
|
+
"""
|
|
540
|
+
# Detect if output_format is a structured type
|
|
541
|
+
is_structured_output = self._is_structured_output_format(output_format)
|
|
542
|
+
structured_output_class = output_format if is_structured_output else None
|
|
543
|
+
|
|
544
|
+
# Parse user role
|
|
545
|
+
if isinstance(user_role, str):
|
|
546
|
+
user_role = UserRole(user_role.lower())
|
|
547
|
+
|
|
548
|
+
# Add retry configuration to kwargs
|
|
549
|
+
retry_config = kwargs.pop('retry_config', QueryRetryConfig())
|
|
550
|
+
|
|
551
|
+
# Override temperature to ensure consistent database operations
|
|
552
|
+
kwargs['temperature'] = kwargs.get('temperature', self._default_temperature)
|
|
553
|
+
|
|
554
|
+
# Generate session ID if not provided
|
|
555
|
+
if not session_id:
|
|
556
|
+
session_id = f"db_session_{hash(query + str(user_id))}"
|
|
557
|
+
|
|
558
|
+
# Parse output components
|
|
559
|
+
_components = self._parse_components(
|
|
560
|
+
user_role, output_components, add_components, remove_components
|
|
561
|
+
)
|
|
562
|
+
|
|
563
|
+
try:
|
|
564
|
+
# Step 1: Get conversation context
|
|
565
|
+
conversation_history = None
|
|
566
|
+
conversation_context = ""
|
|
567
|
+
|
|
568
|
+
if use_conversation_history and self.conversation_memory:
|
|
569
|
+
try:
|
|
570
|
+
conversation_history = await self.get_conversation_history(user_id, session_id)
|
|
571
|
+
if not conversation_history:
|
|
572
|
+
conversation_history = await self.create_conversation_history(user_id, session_id)
|
|
573
|
+
conversation_context = self.build_conversation_context(conversation_history)
|
|
574
|
+
except Exception as e:
|
|
575
|
+
self.logger.warning(f"Failed to load conversation history: {e}")
|
|
576
|
+
|
|
577
|
+
# Step 2: Get vector context from knowledge store
|
|
578
|
+
vector_context = ""
|
|
579
|
+
vector_metadata = {}
|
|
580
|
+
|
|
581
|
+
if self.knowledge_store:
|
|
582
|
+
try:
|
|
583
|
+
search_results = await self.knowledge_store.similarity_search(query, k=5)
|
|
584
|
+
if search_results:
|
|
585
|
+
vector_context = "\n\n".join(
|
|
586
|
+
[doc.page_content for doc in search_results]
|
|
587
|
+
)
|
|
588
|
+
vector_metadata = {
|
|
589
|
+
'sources': [doc.metadata.get('source', 'unknown') for doc in search_results],
|
|
590
|
+
'tables_referenced': [
|
|
591
|
+
doc.metadata.get('table_name')
|
|
592
|
+
for doc in search_results
|
|
593
|
+
if doc.metadata.get('table_name')
|
|
594
|
+
]
|
|
595
|
+
}
|
|
596
|
+
self.logger.debug(
|
|
597
|
+
f"Retrieved vector context from {len(search_results)} sources"
|
|
598
|
+
)
|
|
599
|
+
except Exception as e:
|
|
600
|
+
self.logger.warning(f"Error retrieving vector context: {e}")
|
|
601
|
+
except Exception as e:
|
|
602
|
+
self.logger.warning(f"Error preparing context: {e}")
|
|
603
|
+
conversation_context = ""
|
|
604
|
+
vector_context = ""
|
|
605
|
+
vector_metadata = {}
|
|
606
|
+
|
|
607
|
+
try:
|
|
608
|
+
# Step 3: Route the query
|
|
609
|
+
route: RouteDecision = await self.query_router.route(
|
|
610
|
+
query=query,
|
|
611
|
+
user_role=user_role,
|
|
612
|
+
output_components=_components
|
|
613
|
+
)
|
|
614
|
+
|
|
615
|
+
self.logger.info(
|
|
616
|
+
f"Query Routed: intent={route.intent.value}, "
|
|
617
|
+
f"schema={route.primary_schema}, "
|
|
618
|
+
f"role={route.user_role.value}, components={route.components}"
|
|
619
|
+
)
|
|
620
|
+
|
|
621
|
+
# Step 4: Discover metadata (if needed)
|
|
622
|
+
metadata_context = ""
|
|
623
|
+
discovered_tables = []
|
|
624
|
+
if route.needs_metadata_discovery or route.intent in [QueryIntent.EXPLORE_SCHEMA, QueryIntent.EXPLAIN_METADATA]:
|
|
625
|
+
self.logger.debug("🔍 Starting metadata discovery...")
|
|
626
|
+
metadata_context, discovered_tables = await self._discover_metadata(query)
|
|
627
|
+
self.logger.info(
|
|
628
|
+
f"✅ DISCOVERED: {len(discovered_tables)} tables with context length: {len(metadata_context)}"
|
|
629
|
+
)
|
|
630
|
+
|
|
631
|
+
self.logger.info(
|
|
632
|
+
f"Processing database query: use_tools=True, "
|
|
633
|
+
f"available_tools={len(self.tool_manager.get_tools())}"
|
|
634
|
+
)
|
|
635
|
+
|
|
636
|
+
# Step 5: Generate/validate query (if needed)
|
|
637
|
+
db_response, llm_response = await self._process_query(
|
|
638
|
+
query=query,
|
|
639
|
+
route=route,
|
|
640
|
+
metadata_context=metadata_context,
|
|
641
|
+
discovered_tables=discovered_tables,
|
|
642
|
+
conversation_context=conversation_context,
|
|
643
|
+
vector_context=vector_context,
|
|
644
|
+
user_context=user_context,
|
|
645
|
+
enable_retry=enable_retry,
|
|
646
|
+
retry_config=retry_config,
|
|
647
|
+
context=context,
|
|
648
|
+
**kwargs
|
|
649
|
+
)
|
|
650
|
+
|
|
651
|
+
# Step 6: Format Final response, with response output
|
|
652
|
+
return await self._format_response(
|
|
653
|
+
query=query,
|
|
654
|
+
db_response=db_response,
|
|
655
|
+
is_structured_output=is_structured_output,
|
|
656
|
+
structured_output_class=structured_output_class,
|
|
657
|
+
route=route,
|
|
658
|
+
llm_response=llm_response,
|
|
659
|
+
output_format=output_format,
|
|
660
|
+
discovered_tables=discovered_tables,
|
|
661
|
+
**kwargs
|
|
662
|
+
)
|
|
663
|
+
|
|
664
|
+
except Exception as e:
|
|
665
|
+
self.logger.error(
|
|
666
|
+
f"Error in enhanced ask method: {e}"
|
|
667
|
+
)
|
|
668
|
+
return self._create_error_response(query, e, user_role)
|
|
669
|
+
|
|
670
|
+
async def _use_schema_search_tool(self, user_query: str) -> Optional[str]:
|
|
671
|
+
"""Use schema search tool to discover relevant metadata."""
|
|
672
|
+
try:
|
|
673
|
+
# Direct call to schema tool
|
|
674
|
+
search_results = await self.schema_tool.search_schema(
|
|
675
|
+
search_term=user_query,
|
|
676
|
+
search_type="all",
|
|
677
|
+
limit=5
|
|
678
|
+
)
|
|
679
|
+
|
|
680
|
+
if search_results:
|
|
681
|
+
self.logger.info(
|
|
682
|
+
f"Found {len(search_results)} tables via schema tool"
|
|
683
|
+
)
|
|
684
|
+
metadata_parts = []
|
|
685
|
+
for table in search_results:
|
|
686
|
+
metadata_parts.append(table.to_yaml_context())
|
|
687
|
+
return "\n---\n".join(metadata_parts)
|
|
688
|
+
|
|
689
|
+
except Exception as e:
|
|
690
|
+
self.logger.error(
|
|
691
|
+
f"Schema tool failed: {e}"
|
|
692
|
+
)
|
|
693
|
+
|
|
694
|
+
return None
|
|
695
|
+
|
|
696
|
+
async def _discover_metadata(self, query: str) -> Tuple[str, List[TableMetadata]]:
|
|
697
|
+
"""
|
|
698
|
+
Discover relevant metadata for the query across allowed schemas.
|
|
699
|
+
|
|
700
|
+
Returns:
|
|
701
|
+
Tuple[str, List[TableMetadata]]: (metadata_context, discovered_tables)
|
|
702
|
+
"""
|
|
703
|
+
self.logger.debug(
|
|
704
|
+
f"🔍 DISCOVERY: Starting metadata discovery for query: '{query}'"
|
|
705
|
+
)
|
|
706
|
+
|
|
707
|
+
discovered_tables = []
|
|
708
|
+
metadata_parts = []
|
|
709
|
+
|
|
710
|
+
# Step 1: Direct schema search using table name extraction
|
|
711
|
+
table_name = self._extract_table_name_from_query(query)
|
|
712
|
+
|
|
713
|
+
if table_name:
|
|
714
|
+
self.logger.debug(
|
|
715
|
+
f"📋 Extracted table name: {table_name}"
|
|
716
|
+
)
|
|
717
|
+
# Search for exact table match first
|
|
718
|
+
for schema in self.allowed_schemas:
|
|
719
|
+
table_metadata = await self.metadata_cache.get_table_metadata(
|
|
720
|
+
schema,
|
|
721
|
+
table_name
|
|
722
|
+
)
|
|
723
|
+
if table_metadata:
|
|
724
|
+
self.logger.info(f"✅ EXACT MATCH: Found {schema}.{table_name}")
|
|
725
|
+
discovered_tables.append(table_metadata)
|
|
726
|
+
metadata_parts.append(table_metadata.to_yaml_context())
|
|
727
|
+
break
|
|
728
|
+
|
|
729
|
+
# Step 2: If no exact match, try more precise fuzzy search
|
|
730
|
+
if not discovered_tables and table_name:
|
|
731
|
+
self.logger.debug("🔄 No exact match, performing targeted fuzzy search...")
|
|
732
|
+
|
|
733
|
+
# Search specifically for the table name, not the entire query
|
|
734
|
+
similar_tables = await self.schema_tool.search_schema(
|
|
735
|
+
search_term=table_name, # Use ONLY the table name, not entire query
|
|
736
|
+
search_type="table_name", # Focus on table names only
|
|
737
|
+
limit=3 # Reduce limit to avoid noise
|
|
738
|
+
)
|
|
739
|
+
|
|
740
|
+
if similar_tables:
|
|
741
|
+
self.logger.info(f"🎯 FUZZY SEARCH: Found {len(similar_tables)} similar tables")
|
|
742
|
+
discovered_tables.extend(similar_tables)
|
|
743
|
+
for table in similar_tables:
|
|
744
|
+
metadata_parts.append(table.to_yaml_context())
|
|
745
|
+
else:
|
|
746
|
+
# If still no results, be explicit about missing table
|
|
747
|
+
self.logger.warning(
|
|
748
|
+
f"❌ TABLE NOT FOUND: '{table_name}' not found in any schema"
|
|
749
|
+
)
|
|
750
|
+
return self._create_table_not_found_response(table_name, query), []
|
|
751
|
+
|
|
752
|
+
# Step 3: Fallback to hot tables if still no results
|
|
753
|
+
if not discovered_tables:
|
|
754
|
+
self.logger.warning("⚠️ No specific tables found, using hot tables fallback")
|
|
755
|
+
hot_tables = self.metadata_cache.get_hot_tables(self.allowed_schemas, limit=3)
|
|
756
|
+
|
|
757
|
+
for schema_name, table_name, access_count in hot_tables:
|
|
758
|
+
table_meta = await self.metadata_cache.get_table_metadata(schema_name, table_name)
|
|
759
|
+
if table_meta:
|
|
760
|
+
discovered_tables.append(table_meta)
|
|
761
|
+
metadata_parts.append(table_meta.to_yaml_context())
|
|
762
|
+
|
|
763
|
+
# Combine metadata context
|
|
764
|
+
metadata_context = "\n---\n".join(metadata_parts) if metadata_parts else ""
|
|
765
|
+
|
|
766
|
+
if not metadata_context:
|
|
767
|
+
# Absolute fallback
|
|
768
|
+
metadata_context = f"Available schemas: {', '.join(self.allowed_schemas)} (primary: {self.primary_schema})"
|
|
769
|
+
self.logger.warning(
|
|
770
|
+
"⚠️ Using minimal fallback context"
|
|
771
|
+
)
|
|
772
|
+
|
|
773
|
+
self.logger.info(
|
|
774
|
+
f"🏁 DISCOVERY COMPLETE: {len(discovered_tables)} tables, "
|
|
775
|
+
f"context length: {len(metadata_context)} chars"
|
|
776
|
+
)
|
|
777
|
+
|
|
778
|
+
return metadata_context, discovered_tables
|
|
779
|
+
|
|
780
|
+
def _create_table_not_found_response(self, table_name: str, original_query: str) -> str:
|
|
781
|
+
"""Create a clear response when table doesn't exist."""
|
|
782
|
+
return f"""**Table Not Found**: `{table_name}`
|
|
783
|
+
The table `{table_name}` does not exist in any of the available schemas: {', '.join(self.allowed_schemas)}
|
|
784
|
+
|
|
785
|
+
**Available options:**
|
|
786
|
+
1. Check table name spelling
|
|
787
|
+
2. Use: "show tables" to list available tables
|
|
788
|
+
3. Search for similar tables: "find tables like {table_name[:5]}"
|
|
789
|
+
|
|
790
|
+
**Available schemas:** {', '.join([f'`{s}`' for s in self.allowed_schemas])}
|
|
791
|
+
"""
|
|
792
|
+
|
|
793
|
+
def _extract_table_name_from_query(self, query: str) -> Optional[str]:
|
|
794
|
+
"""Extract table name with better precision."""
|
|
795
|
+
# Enhanced patterns with word boundaries and more specific matching
|
|
796
|
+
patterns = [
|
|
797
|
+
r'\bfrom\s+(?:[\w.]+\.)?(\w+)', # "from schema.table" or "from table"
|
|
798
|
+
r'\btable\s+(?:[\w.]+\.)?(\w+)', # "table schema.table" or "table name"
|
|
799
|
+
r'\bmetadata\s+of\s+(?:table\s+)?(?:[\w.]+\.)?(\w+)', # "metadata of table X"
|
|
800
|
+
r'\bdescribe\s+(?:table\s+)?(?:[\w.]+\.)?(\w+)', # "describe table X"
|
|
801
|
+
r'\bstructure\s+of\s+(?:[\w.]+\.)?(\w+)', # "structure of table"
|
|
802
|
+
r'\binformation\s+about\s+(?:[\w.]+\.)?(\w+)', # "information about table"
|
|
803
|
+
r'\bdetails\s+of\s+(?:[\w.]+\.)?(\w+)', # "details of table"
|
|
804
|
+
r'(?:[\w.]+\.)?(\w+)\s+table\b', # "inventory table"
|
|
805
|
+
# Be more specific about "records from" pattern
|
|
806
|
+
r'\brecords?\s+from\s+(?:[\w.]+\.)?(\w+)', # "records from table"
|
|
807
|
+
r'\bdata\s+from\s+(?:[\w.]+\.)?(\w+)', # "data from table"
|
|
808
|
+
]
|
|
809
|
+
|
|
810
|
+
query_lower = query.lower()
|
|
811
|
+
for pattern in patterns:
|
|
812
|
+
match = re.search(pattern, query_lower)
|
|
813
|
+
if match:
|
|
814
|
+
table_name = match.group(1)
|
|
815
|
+
# Filter out common false positives and SQL keywords
|
|
816
|
+
false_positives = {
|
|
817
|
+
'the', 'in', 'from', 'with', 'for', 'about', 'format',
|
|
818
|
+
'return', 'select', 'where', 'order', 'group', 'by',
|
|
819
|
+
'limit', 'offset', 'having', 'distinct'
|
|
820
|
+
}
|
|
821
|
+
if table_name not in false_positives:
|
|
822
|
+
self.logger.debug(f"📋 Extracted table name: '{table_name}' using pattern: {pattern}")
|
|
823
|
+
return table_name
|
|
824
|
+
|
|
825
|
+
return None
|
|
826
|
+
|
|
827
|
+
    async def _generate_schema(
        self,
        query: str,
        metadata_context: str,
        schema_name: str
    ) -> str:
        """
        Generate explanation for schema exploration queries.

        Used when users ask about table metadata, schema structure, etc.
        """

        # Extract table name if mentioned in query
        table_name = self._extract_table_name_from_query(query)

        if table_name:
            # Get specific table metadata
            table_metadata = await self.get_table_metadata(schema_name, table_name)
            if table_metadata:
                explanation = f"**Table: `{table_metadata.full_name}`**\n\n"
                explanation += table_metadata.to_yaml_context()
                return explanation

        # General schema information
        if metadata_context:
            explanation = f"**Schema Information for `{schema_name}`:**\n\n"
            explanation += metadata_context
            return explanation

        # Fallback
        return f"Schema `{schema_name}` information. Use schema exploration tools for detailed structure."

    async def _query_generation(
        self,
        query: str,
        route: RouteDecision,
        metadata_context: str,
        context: Optional[str] = None,
        **kwargs
    ) -> Tuple[str, str, AIMessage]:
        """Generate SQL query using LLM based on user request and metadata."""
        self.logger.debug(
            f"🔍 QUERY GEN: Generating SQL for intent '{route.intent.value}' "
            f"with components {route.components}"
        )
        system_prompt = f"""
You are a PostgreSQL query expert for multi-schema databases.

**Database Context:**
**Primary Schema:** {self.primary_schema}
**Allowed Schemas:** {', '.join(self.allowed_schemas)}

**Context Information:**
{context}

**Available Tables and Structure:**
{metadata_context}

**Instructions:**
1. Generate PostgreSQL queries using only these schemas: {', '.join([f'"{schema}"' for schema in self.allowed_schemas])}
2. If you can generate a query using the available tables/columns, return ONLY the SQL query in a ```sql code block
3. NEVER invent table names - only use tables from the metadata above
4. If metadata is insufficient, use schema exploration tools
5. If you CANNOT generate a query (missing tables, columns, etc.), explain WHY in plain text - do NOT use code blocks
6. For "show me" queries, generate simple SELECT statements
7. Always include appropriate LIMIT clauses
8. Prefer primary schema "{self.primary_schema}" unless user specifies otherwise

**COLUMN SELECTION STRATEGY:**
1. First, look for EXACT matches to user terms
2. Then, look for SEMANTIC matches (price → pricing)
3. Choose the most appropriate column based on context
4. If multiple columns could work, prefer the most specific one

**QUERY PROCESSING RULES:**
1. ONLY use tables and columns from the metadata above - NEVER invent names
2. When user mentions concepts like "price", find the closest actual column name
3. Generate clean, readable PostgreSQL queries
4. Always include appropriate LIMIT clauses for "top N" requests
5. Use proper schema qualification: "{self.primary_schema}".table_name

**User Intent:** {route.intent.value}

Analyze the request and either generate a valid PostgreSQL query OR explain why it cannot be fulfilled.
Apply semantic understanding to map user concepts to available columns.

**Your Task:** Analyze the user request and provide either a SQL query OR a clear explanation.
"""
        # Call LLM for query generation
        async with self._llm as client:
            llm_response = await client.ask(
                prompt=f"User request: {query}",
                system_prompt=system_prompt,
                **kwargs
            )

        # Extract SQL and explanation
        response_text = str(llm_response.output) if llm_response.output else str(llm_response.response)
        # 🔍 DEBUG: Log what LLM actually said
        self.logger.info(f"🤖 LLM RESPONSE: {response_text[:200]}...")
        sql_query = self._extract_sql_from_response(response_text)

        if not sql_query:
            if self._is_explanatory_response(response_text):
                self.logger.info(f"🔍 LLM PROVIDED EXPLANATION: No SQL generated, but explanation available")
                return None, response_text, llm_response
            else:
                self.logger.warning(f"🔍 LLM RESPONSE UNCLEAR: No SQL found and doesn't look like explanation")

        return sql_query, response_text, llm_response

    def _is_explanatory_response(self, response_text: str) -> bool:
        """Detect if the LLM response is an explanation rather than SQL."""

        # Clean the response for analysis
        cleaned_text = response_text.strip().lower()

        # Patterns that indicate explanatory responses
        explanation_patterns = [
            "i cannot",
            "i'm sorry",
            "i am sorry",
            "unable to",
            "cannot fulfill",
            "cannot generate",
            "cannot create",
            "the table",
            "the metadata",
            "does not contain",
            "missing",
            "not found",
            "no table",
            "no column",
            "not available",
            "insufficient information",
            "please provide",
            "you need to"
        ]

        # Check if response contains explanatory language
        contains_explanation = any(pattern in cleaned_text for pattern in explanation_patterns)

        # Check if response lacks SQL patterns
        sql_patterns = ['select', 'from', 'where', 'order by', 'group by', 'insert', 'update', 'delete']
        contains_sql = any(pattern in cleaned_text for pattern in sql_patterns)

        # It's explanatory if it has explanation patterns but no SQL
        is_explanatory = contains_explanation and not contains_sql

        self.logger.debug(
            f"🔍 EXPLANATION CHECK: explanation_patterns={contains_explanation}, sql_patterns={contains_sql}, is_explanatory={is_explanatory}"
        )
        return is_explanatory

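# Illustrative sketch (not part of the package): the explanation-vs-SQL heuristic above,
# reduced to a standalone function with a shortened marker list so it can be exercised
# in isolation. `looks_like_explanation` is a hypothetical name used only for this demo.
def looks_like_explanation(text: str) -> bool:
    """True when the text reads like an apology/explanation and contains no SQL keywords."""
    cleaned = text.strip().lower()
    explanation_markers = ("i cannot", "unable to", "does not contain", "not found")
    sql_markers = ("select", "from", "where", "insert", "update", "delete")
    has_explanation = any(m in cleaned for m in explanation_markers)
    has_sql = any(m in cleaned for m in sql_markers)
    return has_explanation and not has_sql

# looks_like_explanation("I cannot generate a query: the table is missing")  -> True
# looks_like_explanation("SELECT * FROM inventory LIMIT 10")                 -> False
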
    async def _execute_query_explain(
        self,
        sql_query: str,
    ) -> QueryExecutionResponse:
        """Execute query with EXPLAIN ANALYZE for performance analysis."""

        start_time = datetime.now()

        try:
            async with self.session_maker() as session:
                # Set search path for security
                search_path = ','.join(self.allowed_schemas)
                await session.execute(text(f"SET search_path = '{search_path}'"))

                # Execute EXPLAIN ANALYZE first
                explain_query = f"EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) {sql_query}"

                try:
                    plan_result = await session.execute(text(explain_query))
                    query_plan_json = plan_result.fetchone()[0]

                    # Convert JSON plan to readable format
                    query_plan = self._format_explain_plan(query_plan_json)

                    print('FORMAT PLAN > ', query_plan)

                    execution_time = (datetime.now() - start_time).total_seconds() * 1000

                    return QueryExecutionResponse(
                        success=True,
                        data=None,  # EXPLAIN doesn't return data
                        row_count=0,
                        execution_time_ms=execution_time,
                        query_plan=query_plan,
                        schema_used=self.primary_schema,
                        metadata={
                            "plan_json": query_plan_json,  # Store JSON for metrics extraction
                            "query_type": "explain_analyze"
                        }
                    )

                except Exception as e:
                    # If EXPLAIN fails, the query has syntax/table issues
                    execution_time = (datetime.now() - start_time).total_seconds() * 1000

                    return QueryExecutionResponse(
                        success=False,
                        data=None,
                        row_count=0,
                        execution_time_ms=execution_time,
                        error_message=str(e),
                        schema_used=self.primary_schema
                    )

        except Exception as e:
            execution_time = (datetime.now() - start_time).total_seconds() * 1000

            return QueryExecutionResponse(
                success=False,
                data=None,
                row_count=0,
                execution_time_ms=execution_time,
                error_message=f"Database connection error: {str(e)}",
                schema_used=self.primary_schema
            )

    def _format_explain_plan(self, plan_json) -> str:
        """Format EXPLAIN ANALYZE JSON output to comprehensive readable text for developers."""
        if not plan_json or not isinstance(plan_json, list):
            return "No execution plan available"

        try:
            plan_data = plan_json[0]
            main_plan = plan_data.get("Plan", {})

            # Build comprehensive formatted output
            lines = []

            # Header with overall timing
            lines.append("=" * 60)
            lines.append("POSTGRESQL EXECUTION PLAN ANALYSIS")
            lines.append("=" * 60)

            # Overall execution statistics
            if "Execution Time" in plan_data:
                lines.append(f"📊 **Overall Execution Time:** {plan_data['Execution Time']:.3f}ms")
            if "Planning Time" in plan_data:
                lines.append(f"🧠 **Planning Time:** {plan_data['Planning Time']:.3f}ms")

            lines.append("")
            lines.append("🔍 **Detailed Node Analysis:**")
            lines.append("-" * 40)

            def format_node_detailed(node, level=0):
                indent = " " * level
                node_type = node.get("Node Type", "Unknown")
                node_lines = []

                # Main node header
                node_lines.append(f"{indent}{'└─' if level > 0 else '▶'} **{node_type}**")

                # Cost analysis (startup vs total)
                startup_cost = node.get("Startup Cost", 0)
                total_cost = node.get("Total Cost", 0)
                if startup_cost or total_cost:
                    node_lines.append(f"{indent} 💰 Cost: {startup_cost:.2f}..{total_cost:.2f}")

                # Timing details (startup vs total)
                startup_time = node.get("Actual Startup Time")
                total_time = node.get("Actual Total Time")
                if startup_time is not None and total_time is not None:
                    node_lines.append(f"{indent} ⏱️ Time: {startup_time:.3f}ms..{total_time:.3f}ms")

                # Row estimates vs actual
                plan_rows = node.get("Plan Rows")
                actual_rows = node.get("Actual Rows")
                if plan_rows is not None or actual_rows is not None:
                    estimate_accuracy = ""
                    if plan_rows and actual_rows:
                        ratio = actual_rows / plan_rows if plan_rows > 0 else float('inf')
                        if ratio > 2 or ratio < 0.5:
                            estimate_accuracy = f" ⚠️ {'Over' if ratio > 1 else 'Under'}estimated by {ratio:.1f}x"

                    node_lines.append(f"{indent} 📊 Rows: {plan_rows or 'N/A'} planned → {actual_rows or 'N/A'} actual{estimate_accuracy}")

                # Loop information
                loops = node.get("Actual Loops", 1)
                if loops > 1:
                    node_lines.append(f"{indent} 🔄 Loops: {loops}")

                # Width (average row size)
                if "Plan Width" in node:
                    node_lines.append(f"{indent} 📏 Avg Row Size: {node['Plan Width']} bytes")

                # Table/relation information
                if "Relation Name" in node:
                    table_info = f"{indent} 🗂️ Table: {node['Relation Name']}"
                    if "Alias" in node and node["Alias"] != node["Relation Name"]:
                        table_info += f" (as {node['Alias']})"
                    node_lines.append(table_info)

                # Index information
                if "Index Name" in node:
                    index_info = f"{indent} 🔑 Index: {node['Index Name']}"
                    if "Scan Direction" in node:
                        index_info += f" ({node['Scan Direction']} scan)"
                    node_lines.append(index_info)

                # Join/Filter conditions
                if "Hash Cond" in node:
                    node_lines.append(f"{indent} 🔗 Hash Condition: {node['Hash Cond']}")
                if "Index Cond" in node:
                    node_lines.append(f"{indent} 🎯 Index Condition: {node['Index Cond']}")
                if "Filter" in node:
                    node_lines.append(f"{indent} 🔍 Filter: {node['Filter']}")
                if "Rows Removed by Filter" in node:
                    node_lines.append(f"{indent} ❌ Filtered out: {node['Rows Removed by Filter']} rows")

                # Sort information
                if "Sort Key" in node:
                    node_lines.append(f"{indent} 🔤 Sort Key: {', '.join(node['Sort Key'])}")
                if "Sort Method" in node:
                    sort_info = f"{indent} 📈 Sort Method: {node['Sort Method']}"
                    if "Sort Space Used" in node and "Sort Space Type" in node:
                        sort_info += f" ({node['Sort Space Used']}kB {node['Sort Space Type']})"
                    node_lines.append(sort_info)

                # Buffer usage (I/O statistics)
                buffer_info = []
                buffer_fields = [
                    ("Shared Hit Blocks", "💾 Shared Hit"),
                    ("Shared Read Blocks", "💿 Shared Read"),
                    ("Shared Dirtied Blocks", "✏️ Shared Dirty"),
                    ("Shared Written Blocks", "💾 Shared Write"),
                    ("Temp Read Blocks", "🌡️ Temp Read"),
                    ("Temp Written Blocks", "🌡️ Temp Write")
                ]

                for field, label in buffer_fields:
                    if field in node and node[field] > 0:
                        buffer_info.append(f"{label}: {node[field]}")

                if buffer_info:
                    node_lines.append(f"{indent} 📊 Buffers: {' | '.join(buffer_info)}")

                # Parallelism information
                if node.get("Parallel Aware") and "Workers Planned" in node:
                    parallel_info = f"{indent} ⚡ Parallel: {node.get('Workers Planned', 0)} workers planned"
                    if "Workers Launched" in node:
                        parallel_info += f", {node['Workers Launched']} launched"
                    node_lines.append(parallel_info)

                # Memory usage
                if "Hash Buckets" in node:
                    memory_info = f"{indent} 🧠 Hash: {node['Hash Buckets']} buckets"
                    if "Hash Batches" in node:
                        memory_info += f", {node['Hash Batches']} batches"
                    if "Peak Memory Usage" in node:
                        memory_info += f", {node['Peak Memory Usage']}kB peak"
                    node_lines.append(memory_info)

                # Add blank line after each major node
                node_lines.append("")

                # Process child nodes recursively
                if "Plans" in node and node["Plans"]:
                    for child in node["Plans"]:
                        node_lines.extend(format_node_detailed(child, level + 1))

                return node_lines

            # Format the main execution tree
            formatted_lines = format_node_detailed(main_plan)
            lines.extend(formatted_lines)

            # Overall statistics summary
            lines.append("=" * 40)
            lines.append("📈 **EXECUTION SUMMARY**")
            lines.append("=" * 40)

            def extract_totals(node, totals=None):
                if totals is None:
                    totals = {
                        'total_cost': 0,
                        'total_time': 0,
                        'total_rows': 0,
                        'seq_scans': 0,
                        'index_scans': 0,
                        'sorts': 0,
                        'joins': 0
                    }

                # Accumulate costs and times
                totals['total_cost'] += node.get('Total Cost', 0)
                totals['total_time'] += node.get('Actual Total Time', 0)
                totals['total_rows'] += node.get('Actual Rows', 0)

                # Count operation types
                node_type = node.get('Node Type', '').lower()
                if 'seq scan' in node_type:
                    totals['seq_scans'] += 1
                elif 'index scan' in node_type or 'index only scan' in node_type:
                    totals['index_scans'] += 1
                elif 'sort' in node_type:
                    totals['sorts'] += 1
                elif 'join' in node_type:
                    totals['joins'] += 1

                # Recurse into child plans
                if 'Plans' in node:
                    for child in node['Plans']:
                        extract_totals(child, totals)

                return totals

            totals = extract_totals(main_plan)

            lines.append(f"• Total Estimated Cost: {totals['total_cost']:.2f}")
            lines.append(f"• Sequential Scans: {totals['seq_scans']}")
            lines.append(f"• Index Scans: {totals['index_scans']}")
            lines.append(f"• Sort Operations: {totals['sorts']}")
            lines.append(f"• Join Operations: {totals['joins']}")

            # Performance indicators
            lines.append("\n🎯 **PERFORMANCE INDICATORS:**")
            performance_notes = []

            if totals['seq_scans'] > 0:
                performance_notes.append("⚠️ Sequential scans detected - consider indexing")
            if totals['sorts'] > 1:
                performance_notes.append("📈 Multiple sorts - check ORDER BY optimization")
            if totals['total_cost'] > 1000:
                performance_notes.append("💰 High cost query - review for optimization opportunities")

            if performance_notes:
                lines.extend([f"• {note}" for note in performance_notes])
            else:
                lines.append("• ✅ No obvious performance issues detected")

            return "\n".join(lines)

        except Exception as e:
            return f"Error formatting execution plan: {str(e)}\n\nRaw JSON: {str(plan_json)[:500]}..."

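# Illustrative sketch (not part of the package): the JSON shape that
# EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) returns and that _format_explain_plan walks
# recursively. The plan below is hand-written example data, not real PostgreSQL output.
sample_plan = [{
    "Planning Time": 0.120,
    "Execution Time": 3.450,
    "Plan": {
        "Node Type": "Seq Scan",
        "Relation Name": "inventory",
        "Startup Cost": 0.00,
        "Total Cost": 35.50,
        "Plan Rows": 2550,
        "Actual Rows": 2601,
        "Actual Total Time": 3.2,
        "Plans": [],
    },
}]

def walk(node: dict, level: int = 0) -> None:
    """Minimal version of the recursive node traversal used by the formatter."""
    print("  " * level + node.get("Node Type", "Unknown"))
    for child in node.get("Plans", []):
        walk(child, level + 1)

walk(sample_plan[0]["Plan"])  # prints the node-type tree, e.g. "Seq Scan"
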
    async def _generate_query(
        self,
        query: str,
        route: RouteDecision,
        metadata_context: str,
        conversation_context: str,
        vector_context: str,
        user_context: Optional[str],
        context: Optional[str] = None,
        **kwargs
    ) -> Tuple[Optional[str], Optional[str], Optional[AIMessage]]:
        """
        Generate SQL query based on user request and context.

        Adapts the existing _process_query_generation method to work with components.
        """

        # For schema exploration, don't generate SQL - use schema tools
        if route.intent.value in ['explore_schema', 'explain_metadata']:
            explanation = await self._generate_schema(
                query, metadata_context, route.primary_schema
            )
            return None, explanation, None

        elif route.intent.value == 'validate_query':
            # User provided SQL, validate it
            sql_query = query.strip()
            explanation, llm_response = await self._validate_user_sql(
                sql_query=sql_query,
                metadata_context=metadata_context,
                context=context
            )
            return sql_query, explanation, llm_response

        else:
            # Generate new SQL query using the EXISTING method from your code
            sql_query, explanation, llm_response = await self._query_generation(
                query=query,
                route=route,
                metadata_context=metadata_context,
                context=context,
                **kwargs
            )
            return sql_query, explanation, llm_response

    async def _process_query(
        self,
        query: str,
        route: RouteDecision,
        metadata_context: str,
        discovered_tables: List[TableMetadata],
        conversation_context: str,
        vector_context: str,
        user_context: Optional[str],
        enable_retry: bool,
        retry_config: Optional[QueryRetryConfig] = None,
        context: Optional[str] = None,
        **kwargs
    ) -> Tuple[DatabaseResponse, AIMessage]:
        """Process query generation with LLM."""

        db_response = DatabaseResponse(components_included=route.components)
        llm_response = None

        is_documentation_request = (
            'metadata' in query.lower() or
            'documentation' in query.lower() or
            'describe' in query.lower() or
            'structure' in query.lower() or
            route.intent in [QueryIntent.EXPLAIN_METADATA, QueryIntent.EXPLORE_SCHEMA] and
            route.user_role != UserRole.QUERY_DEVELOPER
        )
        if is_documentation_request:
            db_response.is_documentation = True

        if route.user_role == UserRole.QUERY_DEVELOPER:
            # Developers always get raw SQL and data results
            if OutputComponent.SQL_QUERY in route.components:
                sql_query, explanation, llm_response = await self._generate_query(
                    query=query,
                    route=route,
                    metadata_context=metadata_context,
                    conversation_context=conversation_context,
                    vector_context=vector_context,
                    user_context=user_context,
                    context=context,
                    **kwargs
                )
                db_response.query = sql_query
                # Store the generation attempt explanation
                if explanation:
                    db_response.documentation = explanation

            if db_response.query and OutputComponent.EXECUTION_PLAN in route.components:
                self.logger.info(
                    f"🔧 QUERY_DEVELOPER: Attempting execution with EXPLAIN ANALYZE"
                )
                # Try to execute with EXPLAIN ANALYZE
                exec_result = await self._execute_query_explain(db_response.query)

                if exec_result.success:
                    # Extract execution plan
                    if exec_result.query_plan:
                        db_response.execution_plan = exec_result.query_plan

                    # Extract JSON plan data from metadata
                    plan_json = None
                    if exec_result.metadata and "plan_json" in exec_result.metadata:
                        plan_json = exec_result.metadata["plan_json"]

                    # Extract performance metrics with JSON data
                    if OutputComponent.PERFORMANCE_METRICS in route.components:
                        db_response.performance_metrics = self._extract_performance_metrics(
                            exec_result.query_plan,
                            exec_result.execution_time_ms,
                            plan_json=plan_json  # Pass JSON data
                        )

                    # Generate optimization tips with JSON data
                    if OutputComponent.OPTIMIZATION_TIPS in route.components:
                        db_response.optimization_tips, opt_llm_response = await self._generate_optimization_tips(
                            sql_query=db_response.query,
                            query_plan=exec_result.query_plan,
                            metadata_context=metadata_context,
                            context=context,
                            plan_json=plan_json  # Pass JSON data
                        )
                        if opt_llm_response and not llm_response:
                            llm_response = opt_llm_response
                else:
                    # Query failed - provide analysis of why it failed
                    db_response.documentation = f"""**Query Execution Failed**
**Generated SQL:**
```sql
{db_response.query}
```

**Error:** {exec_result.error_message}

**Analysis:** The query could not be executed. This could be due to:
- Table/column name issues
- Syntax errors
- Permission problems
- Schema access restrictions

**Recommendations:**
1. Verify the table exists in the specified schema
2. Check column names and data types
3. Ensure proper schema permissions
"""
                    # Still provide basic optimization tips for the failed query
                    if OutputComponent.OPTIMIZATION_TIPS in route.components:
                        db_response.optimization_tips = [
                            "🔍 Verify table name exists in available schemas",
                            "📋 Use 'SHOW TABLES' to list available tables",
                            "🔧 Check table name spelling and case sensitivity",
                            "📊 Ensure proper schema permissions are granted"
                        ]

            # Always provide schema context for QUERY_DEVELOPER
            if OutputComponent.SCHEMA_CONTEXT in route.components:
                db_response.schema_context = await self._build_schema_context(
                    route.primary_schema,
                    route.allowed_schemas,
                    discovered_tables=discovered_tables
                )

            return db_response, llm_response

        # Generate SQL query (if needed)
        if route.needs_query_generation and OutputComponent.SQL_QUERY in route.components:
            sql_query, explanation, llm_response = await self._generate_query(
                query=query,
                route=route,
                metadata_context=metadata_context,
                conversation_context=conversation_context,
                vector_context=vector_context,
                user_context=user_context,
                context=context,
                **kwargs
            )
            db_response.query = sql_query

            # Store explanation for documentation component
            if OutputComponent.DOCUMENTATION in route.components:
                db_response.documentation = explanation

        # Execute query (if needed)
        if route.needs_execution and db_response.query:
            exec_result = await self._execute_query(
                db_response.query, route, enable_retry, retry_config
            )

            if exec_result.success:
                # Handle data conversion based on components
                if OutputComponent.DATAFRAME_OUTPUT in route.components:
                    if exec_result.data:
                        db_response.data = pd.DataFrame(exec_result.data)
                elif OutputComponent.DATA_RESULTS in route.components:
                    db_response.data = exec_result.data

                db_response.row_count = exec_result.row_count
                db_response.execution_time_ms = exec_result.execution_time_ms

                # Sample data for context
                if OutputComponent.SAMPLE_DATA in route.components and exec_result.data:
                    db_response.sample_data = exec_result.data[:5]  # First 5 rows

            # Execution plan analysis
            if exec_result.query_plan and OutputComponent.EXECUTION_PLAN in route.components:
                db_response.execution_plan = exec_result.query_plan

                # Generate performance metrics
                if OutputComponent.PERFORMANCE_METRICS in route.components:
                    db_response.performance_metrics = self._extract_performance_metrics(
                        exec_result.query_plan, exec_result.execution_time_ms
                    )

            # Generate LLM-based optimization tips
            if OutputComponent.OPTIMIZATION_TIPS in route.components:
                db_response.optimization_tips, llm_response = await self._generate_optimization_tips(
                    sql_query=db_response.query,
                    query_plan=exec_result.query_plan,
                    metadata_context=metadata_context,
                    context=context
                )

        # For documentation requests, format discovered table metadata instead of examples
        if (OutputComponent.DOCUMENTATION in route.components or is_documentation_request) and \
                route.user_role != UserRole.QUERY_DEVELOPER:
            if discovered_tables:
                # Generate detailed documentation for discovered tables
                db_response.documentation = await self._format_table_documentation(
                    discovered_tables, route.user_role, query
                )

        # Generate examples only if NOT a documentation request
        if OutputComponent.EXAMPLES in route.components and not is_documentation_request and \
                route.user_role != UserRole.QUERY_DEVELOPER:
            db_response.examples = await self._generate_examples(
                query, metadata_context, discovered_tables, route.primary_schema
            )

        # Schema context (if requested)
        if OutputComponent.SCHEMA_CONTEXT in route.components:
            db_response.schema_context = await self._build_schema_context(
                route.primary_schema,
                route.allowed_schemas,
                discovered_tables=discovered_tables
            )

        return db_response, llm_response

    async def _format_table_documentation(
        self,
        discovered_tables: List[TableMetadata],
        user_role: UserRole,
        original_query: str
    ) -> str:
        """
        Format discovered table metadata as proper documentation.

        This replaces the generic examples with actual table documentation.
        """
        if not discovered_tables:
            return "No table metadata found for documentation."

        documentation_parts = []

        for table in discovered_tables:
            # Table header
            table_doc = [f"# Table: `{table.full_name}`\n"]

            # Table information
            if table.comment:
                table_doc.append(f"**Description:** {table.comment}\n")

            table_doc.append(f"**Schema:** {table.schema}")
            table_doc.append(f"**Table Name:** {table.tablename}")
            table_doc.append(f"**Type:** {table.table_type}")
            table_doc.append(f"**Row Count:** {table.row_count:,}" if table.row_count else "**Row Count:** Unknown")

            # Column documentation
            if table.columns:
                table_doc.append("\n## Columns\n")

                # Create markdown table for columns
                table_doc.append("| Column Name | Data Type | Nullable | Default | Comment |")
                table_doc.append("|-------------|-----------|----------|---------|---------|")

                for col in table.columns:
                    nullable = "Yes" if col.get('nullable', True) else "No"
                    default_val = col.get('default', '') or ''
                    comment = col.get('comment', '') or ''
                    data_type = col.get('type', 'unknown')

                    # Handle max_length for varchar types
                    if col.get('max_length') and 'character' in data_type.lower():
                        data_type = f"{data_type}({col['max_length']})"

                    table_doc.append(
                        f"| `{col['name']}` | {data_type} | {nullable} | {default_val} | {comment} |"
                    )

            # Primary keys
            if hasattr(table, 'primary_keys') and table.primary_keys:
                table_doc.append(f"\n**Primary Keys:** {', '.join([f'`{pk}`' for pk in table.primary_keys])}")

            # Foreign keys
            if hasattr(table, 'foreign_keys') and table.foreign_keys:
                table_doc.append("\n**Foreign Keys:**")
                for fk in table.foreign_keys:
                    if isinstance(fk, dict):
                        table_doc.append(f"- `{fk.get('column')}` -> `{fk.get('referenced_table')}.{fk.get('referenced_column')}`")

            # Indexes
            if hasattr(table, 'indexes') and table.indexes:
                table_doc.append(f"\n**Indexes:** {len(table.indexes)} indexes defined")

            # CREATE TABLE statement for developers
            if user_role == UserRole.DEVELOPER:
                create_statement = self._generate_create_table_statement(table)
                if create_statement:
                    table_doc.append(f"\n## CREATE TABLE Statement\n\n```sql\n{create_statement}\n```")

            # Sample data (if available and requested)
            if hasattr(table, 'sample_data') and table.sample_data and len(table.sample_data) > 0:
                table_doc.append("\n## Sample Data\n")
                # Show first 3 rows as example
                sample_rows = table.sample_data[:3]
                if sample_rows:
                    # Get column headers
                    headers = list(sample_rows[0].keys()) if sample_rows else []
                    if headers:
                        # Create sample data table
                        table_doc.append("| " + " | ".join(headers) + " |")
                        table_doc.append("| " + " | ".join(['---'] * len(headers)) + " |")

                        for row in sample_rows:
                            values = [str(row.get(h, '')) for h in headers]
                            # Truncate long values
                            values = [v[:50] + '...' if len(str(v)) > 50 else str(v) for v in values]
                            table_doc.append("| " + " | ".join(values) + " |")

            # Access statistics
            if hasattr(table, 'last_accessed') and table.last_accessed:
                table_doc.append(f"\n**Last Accessed:** {table.last_accessed}")
            if hasattr(table, 'access_frequency') and table.access_frequency:
                table_doc.append(f"**Access Frequency:** {table.access_frequency}")

            documentation_parts.append("\n".join(table_doc))

        return "\n\n---\n\n".join(documentation_parts)

    def _generate_create_table_statement(self, table: TableMetadata) -> str:
        """Generate CREATE TABLE statement from table metadata."""
        if not table.columns:
            return ""

        create_parts = [f'CREATE TABLE {table.full_name} (']

        column_definitions = []
        for col in table.columns:
            col_def = f' "{col["name"]}" {col["type"]}'

            # Add NOT NULL constraint
            if not col.get('nullable', True):
                col_def += ' NOT NULL'

            # Add DEFAULT value
            if col.get('default'):
                default_val = col['default']
                # Handle different default value types
                if default_val.lower() in ['now()', 'current_timestamp', 'current_date']:
                    col_def += f' DEFAULT {default_val}'
                elif default_val.replace("'", "").replace('"', '').isdigit():
                    col_def += f' DEFAULT {default_val}'
                else:
                    col_def += f" DEFAULT '{default_val}'"

            column_definitions.append(col_def)

        # Add primary key constraint
        if hasattr(table, 'primary_keys') and table.primary_keys:
            pk_cols = ', '.join([f'"{pk}"' for pk in table.primary_keys])
            column_definitions.append(f' PRIMARY KEY ({pk_cols})')

        create_parts.append(',\n'.join(column_definitions))
        create_parts.append(');')

        # Add table comment if exists
        if table.comment:
            create_parts.append(f"\n\nCOMMENT ON TABLE {table.full_name} IS '{table.comment}';")

        # Add column comments
        for col in table.columns:
            if col.get('comment'):
                create_parts.append(
                    f'COMMENT ON COLUMN {table.full_name}."{col["name"]}" IS \'{col["comment"]}\';'
                )

        return '\n'.join(create_parts)

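# Illustrative sketch (not part of the package): the column-dict shape the DDL generator
# above consumes, and an approximation of the statement it would emit for it. TableMetadata
# is a package class; this dict only mirrors the keys the method reads, and the DDL shown
# in the comment is an assumption about the output layout, not captured output.
example_columns = [
    {"name": "id", "type": "integer", "nullable": False, "default": None, "comment": "surrogate key"},
    {"name": "sku", "type": "character varying(32)", "nullable": False, "default": None, "comment": None},
    {"name": "created_at", "type": "timestamp", "nullable": True, "default": "now()", "comment": None},
]
# Expected output shape (approximate):
#   CREATE TABLE myschema.inventory (
#    "id" integer NOT NULL,
#    "sku" character varying(32) NOT NULL,
#    "created_at" timestamp DEFAULT now(),
#    PRIMARY KEY ("id")
#   );
#   COMMENT ON COLUMN myschema.inventory."id" IS 'surrogate key';
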
    async def _build_schema_context(
        self,
        primary_schema: str,
        allowed_schemas: List[str],
        discovered_tables: List[TableMetadata] = None
    ) -> str:
        """
        Build schema context showing metadata of tables involved in the query.

        Args:
            primary_schema: Primary schema name (for context)
            allowed_schemas: Allowed schemas (for context)
            discovered_tables: List of tables discovered for this query

        Returns:
            Formatted metadata context of the involved tables
        """

        if not discovered_tables:
            return f"""**Query Context:**
No specific tables identified for this query.

**Available Schemas:** {', '.join([f'`{s}`' for s in allowed_schemas])}
**Primary Schema:** `{primary_schema}`

*Use table discovery tools to identify relevant tables for your query.*"""

        context_parts = []

        # Header
        context_parts.append("**TABLES INVOLVED IN QUERY**")
        context_parts.append("=" * 50)

        for i, table in enumerate(discovered_tables, 1):
            # Table header
            context_parts.append(f"\n**{i}. {table.full_name}**")
            context_parts.append(f" Type: {table.table_type}")
            if table.row_count is not None and table.row_count >= 0:
                context_parts.append(f" Rows: {table.row_count:,}")
            if table.comment:
                context_parts.append(f" Description: {table.comment}")

            # Column information in compact format
            if table.columns:
                context_parts.append(f"\n **Columns ({len(table.columns)}):**")

                # Group columns by type for better readability
                column_groups = {}
                for col in table.columns:
                    col_type = col.get('type', 'unknown')
                    # Simplify type names for readability
                    simple_type = self._simplify_column_type(col_type)
                    if simple_type not in column_groups:
                        column_groups[simple_type] = []
                    column_groups[simple_type].append(col)

                # Display columns by type
                for type_name, cols in column_groups.items():
                    col_names = []
                    for col in cols:
                        name = col['name']
                        # Add indicators for special columns
                        if not col.get('nullable', True):
                            name += "*"  # Required field
                        if col.get('default'):
                            name += "°"  # Has default
                        col_names.append(name)

                    context_parts.append(f" • {type_name}: {', '.join(col_names)}")

            # Primary key
            if hasattr(table, 'primary_keys') and table.primary_keys:
                pk_list = ', '.join(table.primary_keys)
                context_parts.append(f" **Primary Key:** {pk_list}")

            # Foreign keys (relationships)
            if hasattr(table, 'foreign_keys') and table.foreign_keys:
                context_parts.append(f" **Relationships:**")
                for fk in table.foreign_keys[:3]:  # Limit to first 3 to avoid clutter
                    if isinstance(fk, dict):
                        ref_table = fk.get('referenced_table', 'unknown')
                        ref_col = fk.get('referenced_column', 'unknown')
                        fk_col = fk.get('column', 'unknown')
                        context_parts.append(f" • {fk_col} → {ref_table}.{ref_col}")

                if len(table.foreign_keys) > 3:
                    context_parts.append(f" • ... and {len(table.foreign_keys) - 3} more")

            # Indexes (for performance context)
            if hasattr(table, 'indexes') and table.indexes:
                idx_count = len(table.indexes)
                context_parts.append(f" **Indexes:** {idx_count} defined")

                # Show a few key indexes
                key_indexes = []
                for idx in table.indexes[:2]:  # Show first 2
                    if isinstance(idx, dict):
                        idx_name = idx.get('name', 'unnamed')
                        idx_cols = idx.get('columns', [])
                        if idx_cols:
                            key_indexes.append(f"{idx_name}({', '.join(idx_cols)})")

                if key_indexes:
                    context_parts.append(f" • Key indexes: {', '.join(key_indexes)}")

        # Add usage legend
        context_parts.append("\n" + "=" * 50)
        context_parts.append("**LEGEND:**")
        context_parts.append("• Column* = Required (NOT NULL)")
        context_parts.append("• Column° = Has default value")
        context_parts.append("• Relationships show foreign key connections")

        # Query development tips specific to these tables
        context_parts.append("\n**QUERY DEVELOPMENT TIPS:**")

        # Generate contextual tips based on discovered tables
        tips = self._generate_table_specific_tips(discovered_tables)
        context_parts.extend([f"• {tip}" for tip in tips])

        return "\n".join(context_parts)

    def _simplify_column_type(self, col_type: str) -> str:
        """Simplify PostgreSQL column types for readable grouping."""
        col_type = col_type.lower()

        # Group similar types
        if 'varchar' in col_type or 'character varying' in col_type or 'text' in col_type:
            return 'Text'
        elif 'int' in col_type or 'serial' in col_type:
            return 'Integer'
        elif 'numeric' in col_type or 'decimal' in col_type or 'float' in col_type or 'double' in col_type:
            return 'Number'
        elif 'timestamp' in col_type or 'date' in col_type or 'time' in col_type:
            return 'DateTime'
        elif 'boolean' in col_type:
            return 'Boolean'
        elif 'uuid' in col_type:
            return 'UUID'
        elif 'json' in col_type:
            return 'JSON'
        elif 'array' in col_type:
            return 'Array'
        else:
            return col_type.title()

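# Illustrative sketch (not part of the package): how the substring checks above bucket a
# few common PostgreSQL types. The expected groupings follow directly from the method body.
# 'character varying(64)'    -> 'Text'
# 'bigint'                   -> 'Integer'
# 'numeric(12,2)'            -> 'Number'
# 'timestamp with time zone' -> 'DateTime'
# 'uuid'                     -> 'UUID'
# 'jsonb'                    -> 'JSON'
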
    def _generate_table_specific_tips(self, discovered_tables: List[TableMetadata]) -> List[str]:
        """Generate query development tips specific to the discovered tables."""
        tips = []

        if not discovered_tables:
            return ["No tables discovered for specific tips"]

        # Analyze the tables for specific tips
        table_names = [table.tablename for table in discovered_tables]
        total_columns = sum(len(table.columns) for table in discovered_tables if table.columns)

        # Tip about table joining
        if len(discovered_tables) > 1:
            tips.append(f"Multiple tables detected - consider JOIN relationships between {', '.join(table_names)}")

        # Tip about column selection
        if total_columns > 20:
            tips.append("Many columns available - use SELECT specific_columns instead of SELECT * for better performance")

        # Tip about primary keys for efficient queries
        pk_tables = [t.tablename for t in discovered_tables if hasattr(t, 'primary_keys') and t.primary_keys]
        if pk_tables:
            tips.append(f"Use primary keys for efficient lookups: {', '.join(pk_tables)}")

        # Tip about large tables
        large_tables = [t.tablename for t in discovered_tables if t.row_count and t.row_count > 100000]
        if large_tables:
            tips.append(f"Large tables detected ({', '.join(large_tables)}) - consider LIMIT clauses and WHERE filtering")

        # Tip about indexes
        indexed_tables = [t.tablename for t in discovered_tables if hasattr(t, 'indexes') and t.indexes]
        if indexed_tables:
            tips.append(f"Indexed tables available - leverage existing indexes for optimal performance")

        # Default tip if no specific tips generated
        if not tips:
            tips.append(
                f"Focus on the {len(discovered_tables)} table(s) structure above for efficient query design"
            )

        return tips[:4]  # Limit to 4 tips

    async def _get_schema_counts_direct(self, schema_name: str) -> Tuple[int, int]:
        """Get table and view counts directly from information_schema."""
        try:
            async with self.session_maker() as session:
                # Count tables
                table_query = text("""
                    SELECT COUNT(*)
                    FROM information_schema.tables
                    WHERE table_schema = :schema_name
                    AND table_type = 'BASE TABLE'
                """)
                table_result = await session.execute(table_query, {"schema_name": schema_name})
                table_count = table_result.scalar() or 0

                # Count views
                view_query = text("""
                    SELECT COUNT(*)
                    FROM information_schema.views
                    WHERE table_schema = :schema_name
                """)
                view_result = await session.execute(view_query, {"schema_name": schema_name})
                view_count = view_result.scalar() or 0

                return table_count, view_count

        except Exception as e:
            self.logger.error(f"Direct schema count failed for {schema_name}: {e}")
            return 0, 0

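# Illustrative sketch (not part of the package): the same information_schema counts issued
# with a plain SQLAlchemy async engine, outside the agent's session_maker. The DSN is a
# placeholder to be replaced; the two queries mirror those used by _get_schema_counts_direct.
from sqlalchemy import text as sa_text
from sqlalchemy.ext.asyncio import create_async_engine

async def count_schema_objects(dsn: str, schema_name: str) -> tuple:
    """Return (table_count, view_count) for a schema using information_schema."""
    engine = create_async_engine(dsn)
    try:
        async with engine.connect() as conn:
            tables = (await conn.execute(
                sa_text("SELECT COUNT(*) FROM information_schema.tables "
                        "WHERE table_schema = :s AND table_type = 'BASE TABLE'"),
                {"s": schema_name},
            )).scalar() or 0
            views = (await conn.execute(
                sa_text("SELECT COUNT(*) FROM information_schema.views "
                        "WHERE table_schema = :s"),
                {"s": schema_name},
            )).scalar() or 0
            return tables, views
    finally:
        await engine.dispose()

# import asyncio
# asyncio.run(count_schema_objects("postgresql+asyncpg://user:pass@localhost/db", "public"))
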
    async def _validate_user_sql(self, sql_query: str, metadata_context: str, context: Optional[str] = None) -> tuple[str, AIMessage]:
        """Validate user-provided SQL."""

        system_prompt = f"""
You are validating SQL for multi-schema access.

```sql
{sql_query}
```

**Context Information:**
{context}

**Primary Schema:** {self.primary_schema}
**Allowed Schemas:** {', '.join(self.allowed_schemas)}

**Available Schema Information:**
{metadata_context}

**Validation Tasks:**
1. Check syntax correctness
2. Verify table/column existence
3. Ensure queries only access allowed schemas: {', '.join(self.allowed_schemas)}
4. Identify potential performance issues
5. Suggest improvements

Provide detailed validation results.
"""
        async with self._llm as client:
            llm_response = await client.ask(
                prompt=f"Validate this SQL query:\n\n```sql\n{sql_query}\n```",
                system_prompt=system_prompt,
                temperature=0.0
            )

        validation_text = str(llm_response.output) if llm_response.output else str(llm_response.response)
        return validation_text, llm_response

    async def _execute_query(
        self,
        sql_query: str,
        route: RouteDecision,
        enable_retry: bool = True,
        retry_config: Optional[QueryRetryConfig] = None
    ) -> QueryExecutionResponse:
        """Execute SQL query with schema security."""

        start_time = datetime.now()
        # Configure execution options based on components
        options = dict(route.execution_options)

        # Component-specific configuration
        if OutputComponent.EXECUTION_PLAN in route.components:
            options['explain_analyze'] = True

        # Apply data limits based on role and components
        if route.include_full_data:
            options['limit'] = None  # No limit for business users
        elif route.data_limit:
            options['limit'] = route.data_limit

        if route.user_role.value == 'database_admin':
            options['timeout'] = 60
        else:
            options.setdefault('timeout', 30)

        # Retry Handler when enable_retry is True
        if enable_retry:
            retry_handler = SQLRetryHandler(self, retry_config or QueryRetryConfig())
            retry_count = 0
            last_error = None
            query_history = []  # Track all attempts

            while retry_count <= retry_handler.config.max_retries:
                try:
                    self.logger.debug(f"🔄 QUERY ATTEMPT {retry_count + 1}: Executing SQL")
                    # Execute the query
                    result = await self._execute_query_internal(sql_query, options)
                    # Success!
                    if retry_count > 0:
                        self.logger.info(
                            f"✅ QUERY SUCCESS: Fixed after {retry_count + 1} retries"
                        )

                    return result
                except Exception as e:
                    self.logger.warning(
                        f"❌ QUERY FAILED (attempt {retry_count + 1}): {e}"
                    )

                    query_history.append({
                        "attempt": retry_count + 1,
                        "query": sql_query,
                        "error": str(e),
                        "error_type": type(e).__name__
                    })

                    last_error = e

                    # Check if this is a retryable error
                    if not retry_handler._is_retryable_error(e):
                        self.logger.info(f"🚫 NON-RETRYABLE ERROR: {type(e).__name__}")
                        break

                    # Check if we've hit max retries
                    if retry_count >= retry_handler.config.max_retries:
                        self.logger.info(f"🛑 MAX RETRIES REACHED: {retry_count}")
                        break

                    # Try to fix the query
                    self.logger.info(
                        f"🔧 ATTEMPTING QUERY FIX: Retry {retry_count + 1}"
                    )

                    try:
                        fixed_query = await self._fix_query(
                            original_query=sql_query,
                            error=e,
                            retry_count=retry_count,
                            query_history=query_history
                        )
                        if fixed_query and fixed_query.strip() != sql_query.strip():
                            sql_query = fixed_query
                            retry_count += 1
                        else:
                            self.logger.warning(
                                f"🔧 NO QUERY FIX: LLM returned same or empty query"
                            )
                            break
                    except Exception as fix_error:
                        self.logger.error(
                            f"🔧 QUERY FIX FAILED: {fix_error}"
                        )
                        break
            # All retries failed, return error response
            execution_time = (datetime.now() - start_time).total_seconds() * 1000
            return QueryExecutionResponse(
                success=False,
                data=None,
                row_count=0,
                execution_time_ms=execution_time,
                schema_used=self.primary_schema,
                error_message=f"Query failed after {retry_count} retries. Last error: {last_error}",
                query_plan=None,
                metadata={
                    "retry_count": retry_count,
                    "query_history": query_history,
                    "last_error_type": type(last_error).__name__ if last_error else None
                }
            )
        else:
            # No retry, single attempt with error handling
            return await self._execute_query_safe(sql_query, options)

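# Illustrative sketch (not part of the package): the retry-and-fix control flow above,
# reduced to a generic loop. `run` and `fix` stand in for query execution and the
# LLM-based repair step; both are hypothetical callables supplied by the caller.
from typing import Awaitable, Callable, Optional, TypeVar

T = TypeVar("T")

async def run_with_fixes(
    statement: str,
    run: Callable[[str], Awaitable[T]],
    fix: Callable[[str, Exception], Awaitable[Optional[str]]],
    max_retries: int = 2,
) -> T:
    """Execute `statement`, asking `fix` for a corrected statement on each failure."""
    attempt = 0
    while True:
        try:
            return await run(statement)
        except Exception as exc:  # broad catch mirrors the loop above
            if attempt >= max_retries:
                raise
            repaired = await fix(statement, exc)
            # Give up when no fix is offered or the fix is identical to the failing statement
            if not repaired or repaired.strip() == statement.strip():
                raise
            statement = repaired
            attempt += 1
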
    async def _execute_query_internal(
        self,
        sql_query: str,
        options: Dict[str, Any]
    ) -> QueryExecutionResponse:
        """Execute query and raise exceptions (don't catch them) for retry mechanism."""

        start_time = datetime.now()

        # Validate query targets correct schemas
        if not self._validate_schema_security(sql_query):
            raise ValueError(
                f"Query attempts to access schemas outside of allowed list: {self.allowed_schemas}"
            )

        # Execute query - LET EXCEPTIONS PROPAGATE for retry mechanism
        async with self.session_maker() as session:
            # Set search path for security
            search_path = ','.join(self.allowed_schemas)
            await session.execute(text(f"SET search_path = '{search_path}'"))

            # Add timeout
            timeout = options.get('timeout', 30)
            await session.execute(text(f"SET statement_timeout = '{timeout}s'"))

            # Execute main query
            query_plan = None
            if options.get('explain_analyze', False):
                # Get query plan first
                plan_result = await session.execute(text(f"EXPLAIN ANALYZE {sql_query}"))
                query_plan = "\n".join([row[0] for row in plan_result.fetchall()])

            # Execute actual query - DON'T CATCH EXCEPTIONS HERE
            result = await session.execute(text(sql_query))

            if sql_query.strip().upper().startswith('SELECT'):
                # Handle SELECT queries
                rows = result.fetchall()
                columns = list(result.keys()) if rows else []

                # Apply limit (a limit of None means "return everything")
                limit = options.get('limit', 1000)
                limited_rows = rows[:limit] if limit is not None and len(rows) > limit else rows

                # Convert to list of dicts
                data = [dict(zip(columns, row)) for row in limited_rows]
                row_count = len(rows)  # Original count
            else:
                # Handle non-SELECT queries
                data = None
                columns = []
                row_count = result.rowcount

            execution_time = (datetime.now() - start_time).total_seconds() * 1000

            return QueryExecutionResponse(
                success=True,
                data=data,
                row_count=row_count,
                execution_time_ms=execution_time,
                columns=columns,
                query_plan=query_plan,
                schema_used=self.primary_schema
            )

    async def _execute_query_safe(
        self,
        sql_query: str,
        options: Dict[str, Any]
    ) -> QueryExecutionResponse:
        """Execute query with error handling (for non-retry scenarios)."""

        start_time = datetime.now()

        try:
            # Use the internal method that raises exceptions
            return await self._execute_query_internal(sql_query, options)

        except Exception as e:
            execution_time = (datetime.now() - start_time).total_seconds() * 1000

            self.logger.error(f"Query execution failed: {e}")

            return QueryExecutionResponse(
                success=False,
                data=None,
                row_count=0,
                execution_time_ms=execution_time,
                error_message=str(e),
                schema_used=self.primary_schema
            )

    async def _fix_query(
        self,
        original_query: str,
        error: Exception,
        retry_count: int,
        query_history: List[Dict[str, Any]]
    ) -> Optional[str]:
        """Use LLM to fix a failed SQL query based on the error."""

        retry_handler = SQLRetryHandler(self)

        # Extract problematic table/column info
        table_name, column_name = retry_handler._extract_table_column_from_error(
            original_query, error
        )

        # Get sample data if possible
        sample_data = ""
        if table_name and column_name:
            sample_data = await retry_handler._get_sample_data_for_error(
                self.primary_schema, table_name, column_name
            )

        # Build error context
        error_context = f"""
**QUERY EXECUTION ERROR:**
Error Type: {type(error).__name__}
Error Message: {str(error)}

**FAILED QUERY:**
```sql
{original_query}
```

**RETRY ATTEMPT:** {retry_count + 1} of {retry_handler.config.max_retries}

{sample_data}

**PREVIOUS ATTEMPTS:**
{self._format_query_history(query_history)}
"""

        # Enhanced system prompt for query fixing
        fix_prompt = f"""
You are a PostgreSQL expert specializing in fixing SQL query errors.

**PRIMARY TASK:** Fix the failed SQL query based on the error message and sample data.

**COMMON ERROR PATTERNS & FIXES:**

💰 **Currency/Number Format Errors:**
- Error: "invalid input syntax for type numeric: '1,999.99'"
- Fix: Remove commas and currency symbols properly
- Example: `CAST(REPLACE(REPLACE(pricing, '$', ''), ',', '') AS NUMERIC)`

📝 **String/Text Conversion Issues:**
- Error: Type conversion failures
- Fix: Use proper casting with text cleaning
- Example: `CAST(TRIM(column_name) AS INTEGER)`

🔤 **Column/Table Name Issues:**
- Error: "column does not exist"
- Fix: Check exact column names from metadata, use proper quoting
- Example: Use "column_name" if names have special characters

**SCHEMA CONTEXT:**
Primary Schema: {self.primary_schema}
Available Schemas: {', '.join(self.allowed_schemas)}

{error_context}

**FIXING INSTRUCTIONS:**
1. Analyze the error message carefully
2. Look at the sample data to understand the actual format
3. Modify the query to handle the data format properly
4. Keep the same business logic (ORDER BY, LIMIT, etc.)
5. Only change what's necessary to fix the error
6. Test your logic against the sample data shown

**OUTPUT:** Return ONLY the corrected SQL query, no explanations.
"""
        try:
            response = await self._llm.ask(
                prompt="Fix the failing SQL query based on the error details above.",
                system_prompt=fix_prompt,
                temperature=0.0  # Deterministic fixes
            )

            fixed_query = self._extract_sql_from_response(
                str(response.output) if response.output else str(response.response)
|
|
2220
|
+
)
|
|
2221
|
+
|
|
2222
|
+
if fixed_query:
|
|
2223
|
+
self.logger.debug(f"FIXED QUERY: {fixed_query}")
|
|
2224
|
+
return fixed_query
|
|
2225
|
+
else:
|
|
2226
|
+
self.logger.warning(f"LLM FIX: No SQL query found in response")
|
|
2227
|
+
return None
|
|
2228
|
+
|
|
2229
|
+
except Exception as e:
|
|
2230
|
+
self.logger.error(f"LLM FIX ERROR: {e}")
|
|
2231
|
+
return None
|
|
2232
|
+
|
|
2233
|
+
    def _format_query_history(self, query_history: List[Dict[str, Any]]) -> str:
        """Format query history for LLM context."""
        if not query_history:
            return "No previous attempts."

        formatted = []
        for attempt in query_history:
            formatted.append(
                f"Attempt {attempt['attempt']}: {attempt['error_type']} - {attempt['error']}"
            )

        return "\n".join(formatted)
|
|
2245
|
+
|
|
2246
|
+
    def _validate_schema_security(self, sql_query: str) -> bool:
        """Ensure query only accesses authorized schemas."""
        query_upper = sql_query.upper()

        # Check for unauthorized schema references
        unauthorized_patterns = [
            r'\bFROM\s+(?!")(\w+)\.',           # FROM schema.table without quotes
            r'\bJOIN\s+(?!")(\w+)\.',           # JOIN schema.table without quotes
            r'\bUPDATE\s+(?!")(\w+)\.',         # UPDATE schema.table without quotes
            r'\bINSERT\s+INTO\s+(?!")(\w+)\.',  # INSERT INTO schema.table without quotes
        ]

        for pattern in unauthorized_patterns:
            matches = re.findall(pattern, query_upper)
            for match in matches:
                if match.upper() not in [schema.upper() for schema in self.allowed_schemas]:
                    self.logger.warning(f"Query attempts to access unauthorized schema: {match}")
                    return False

        # Additional security checks could be added here
        return True
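A standalone copy of the same check, included only to show its behavior on two sample queries; `validate_schema_security` below is a demo function, not the package API.

```python
import logging
import re

logger = logging.getLogger("schema-guard-demo")

def validate_schema_security(sql_query: str, allowed_schemas: list) -> bool:
    """Demo re-implementation of the method above (illustration only)."""
    query_upper = sql_query.upper()
    patterns = [
        r'\bFROM\s+(?!")(\w+)\.',
        r'\bJOIN\s+(?!")(\w+)\.',
        r'\bUPDATE\s+(?!")(\w+)\.',
        r'\bINSERT\s+INTO\s+(?!")(\w+)\.',
    ]
    allowed = {s.upper() for s in allowed_schemas}
    for pattern in patterns:
        for match in re.findall(pattern, query_upper):
            if match not in allowed:
                logger.warning("Unauthorized schema reference: %s", match)
                return False
    return True

print(validate_schema_security("SELECT * FROM hisense.form_data", ["hisense"]))  # True
print(validate_schema_security("SELECT * FROM secret.users", ["hisense"]))       # False
```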
|
|
2267
|
+
|
|
2268
|
+
    def _extract_sql_from_response(self, response_text: str) -> str:
        """Extract SQL query from LLM response."""
        sql_patterns = [
            r'```sql\s*(.*?)\s*```',       # ```sql with optional whitespace
            r'```SQL\s*(.*?)\s*```',       # ```SQL (uppercase)
            r'```\s*(SELECT.*?(?:;|\Z))',  # ``` with SELECT (no sql label)
            r'```\s*(WITH.*?(?:;|\Z))',    # ``` with WITH (no sql label)
        ]

        for pattern in sql_patterns:
            matches = re.findall(pattern, response_text, re.DOTALL | re.IGNORECASE)
            if matches:
                sql = matches[0].strip()
                if sql:
                    self.logger.debug(f"SQL EXTRACTED via pattern: {pattern[:20]}...")
                    return sql

        # Fallback: scan line by line for the first SQL statement
        lines = response_text.split('\n')
        sql_lines = []
        in_sql = False

        for line in lines:
            line_stripped = line.strip()
            line_upper = line_stripped.upper()

            # Start collecting SQL when we see a SQL keyword
            if any(line_upper.startswith(kw) for kw in ['SELECT', 'WITH', 'INSERT', 'UPDATE', 'DELETE']):
                in_sql = True
                sql_lines.append(line_stripped)
            elif in_sql:
                # Continue collecting until we hit a terminator or empty line
                if line_stripped.endswith(';'):
                    sql_lines.append(line_stripped)
                    break
                elif not line_stripped:
                    break
                elif line_stripped.startswith('**') or line_stripped.startswith('#'):
                    # Stop at markdown headers or emphasis
                    break
                else:
                    sql_lines.append(line_stripped)

        if sql_lines:
            sql_query = '\n'.join(sql_lines)
            self.logger.debug("SQL EXTRACTED via fallback parsing")
            return sql_query

        # Last resort: return original if it contains SQL keywords
        if any(kw in response_text.upper() for kw in ['SELECT', 'FROM', 'WHERE']):
            self.logger.warning("Using entire response as SQL (last resort)")
            return response_text.strip()

        self.logger.warning("No SQL found in response")
        return ""
|
|
2322
|
+
|
|
2323
|
+
def _format_as_text(
|
|
2324
|
+
self,
|
|
2325
|
+
db_response: DatabaseResponse,
|
|
2326
|
+
user_role: UserRole,
|
|
2327
|
+
discovered_tables: List[TableMetadata]
|
|
2328
|
+
) -> str:
|
|
2329
|
+
"""Format response as readable text based on user role."""
|
|
2330
|
+
sections = []
|
|
2331
|
+
if db_response.documentation and len(db_response.documentation) > 100:
|
|
2332
|
+
return db_response.documentation
|
|
2333
|
+
|
|
2334
|
+
# Role-specific formatting preferences
|
|
2335
|
+
if user_role == UserRole.BUSINESS_USER:
|
|
2336
|
+
# Simple, data-focused format
|
|
2337
|
+
if db_response.data is not None:
|
|
2338
|
+
if isinstance(db_response.data, pd.DataFrame):
|
|
2339
|
+
sections.append(
|
|
2340
|
+
f"**Results:** {len(db_response.data)} records found"
|
|
2341
|
+
)
|
|
2342
|
+
else:
|
|
2343
|
+
sections.append(
|
|
2344
|
+
f"**Results:** {db_response.row_count} records found"
|
|
2345
|
+
)
|
|
2346
|
+
|
|
2347
|
+
elif user_role == UserRole.DEVELOPER:
|
|
2348
|
+
# For developers requesting metadata, prioritize documentation
|
|
2349
|
+
if db_response.documentation:
|
|
2350
|
+
sections.append(db_response.documentation)
|
|
2351
|
+
elif discovered_tables:
|
|
2352
|
+
# Fallback to basic table info if no documentation generated
|
|
2353
|
+
for table in discovered_tables[:1]: # Show first table
|
|
2354
|
+
sections.append(f"**Table Found:** {table.full_name}")
|
|
2355
|
+
sections.append(f"**Columns:** {len(table.columns)} columns")
|
|
2356
|
+
if table.columns:
|
|
2357
|
+
col_list = ', '.join([f"`{col['name']}`" for col in table.columns[:5]])
|
|
2358
|
+
if len(table.columns) > 5:
|
|
2359
|
+
col_list += f", ... and {len(table.columns) - 5} more"
|
|
2360
|
+
sections.append(f"**Column Names:** {col_list}")
|
|
2361
|
+
|
|
2362
|
+
# Technical focus with examples ONLY if no documentation
|
|
2363
|
+
if not db_response.documentation:
|
|
2364
|
+
if db_response.query:
|
|
2365
|
+
sections.append(f"**SQL Query:**\n```sql\n{db_response.query}\n```")
|
|
2366
|
+
if db_response.examples:
|
|
2367
|
+
examples_text = "\n".join([f"```sql\n{ex}\n```" for ex in db_response.examples])
|
|
2368
|
+
sections.append(f"**Usage Examples:**\n{examples_text}")
|
|
2369
|
+
|
|
2370
|
+
elif user_role == UserRole.DATABASE_ADMIN:
|
|
2371
|
+
# Performance and optimization focus
|
|
2372
|
+
if discovered_tables:
|
|
2373
|
+
sections.append(f"**Analyzed Tables:** {len(discovered_tables)} tables discovered")
|
|
2374
|
+
|
|
2375
|
+
if db_response.documentation:
|
|
2376
|
+
sections.append(db_response.documentation)
|
|
2377
|
+
if db_response.query:
|
|
2378
|
+
sections.append(f"**Query:**\n```sql\n{db_response.query}\n```")
|
|
2379
|
+
if db_response.execution_plan:
|
|
2380
|
+
sections.append(f"**Execution Plan:**\n```\n{db_response.execution_plan}\n```")
|
|
2381
|
+
if db_response.performance_metrics:
|
|
2382
|
+
metrics = "\n".join([f"- {k}: {v}" for k, v in db_response.performance_metrics.items()])
|
|
2383
|
+
sections.append(f"**Performance Metrics:**\n{metrics}")
|
|
2384
|
+
if db_response.optimization_tips:
|
|
2385
|
+
tips = "\n".join([f"- {tip}" for tip in db_response.optimization_tips])
|
|
2386
|
+
sections.append(f"**Optimization Suggestions:**\n{tips}")
|
|
2387
|
+
elif user_role in [UserRole.DATA_ANALYST, UserRole.DATA_SCIENTIST]:
|
|
2388
|
+
# Comprehensive format with data focus
|
|
2389
|
+
if db_response.query:
|
|
2390
|
+
sections.append(f"**SQL Query:**\n```sql\n{db_response.query}\n```")
|
|
2391
|
+
if db_response.data is not None:
|
|
2392
|
+
if isinstance(db_response.data, pd.DataFrame):
|
|
2393
|
+
sections.append(f"**Results:** {len(db_response.data)} records found")
|
|
2394
|
+
else:
|
|
2395
|
+
sections.append(f"**Results:** {db_response.row_count} records found")
|
|
2396
|
+
if db_response.documentation:
|
|
2397
|
+
sections.append(f"**Documentation:**\n{db_response.documentation}")
|
|
2398
|
+
if db_response.examples:
|
|
2399
|
+
examples_text = "\n".join([f"```sql\n{ex}\n```" for ex in db_response.examples])
|
|
2400
|
+
sections.append(f"**Usage Examples:**\n{examples_text}")
|
|
2401
|
+
if db_response.execution_plan:
|
|
2402
|
+
sections.append(f"**Execution Plan:**\n```\n{db_response.execution_plan}\n```")
|
|
2403
|
+
if db_response.performance_metrics:
|
|
2404
|
+
metrics = "\n".join([f"- {k}: {v}" for k, v in db_response.performance_metrics.items()])
|
|
2405
|
+
sections.append(f"**Performance Metrics:**\n{metrics}")
|
|
2406
|
+
if db_response.optimization_tips:
|
|
2407
|
+
tips = "\n".join([f"- {tip}" for tip in db_response.optimization_tips])
|
|
2408
|
+
sections.append(f"**Optimization Suggestions:**\n{tips}")
|
|
2409
|
+
else:
|
|
2410
|
+
# Default comprehensive format for roles not handled above
|
|
2411
|
+
if discovered_tables:
|
|
2412
|
+
sections.append(
|
|
2413
|
+
f"**Schema Analysis:** Found {len(discovered_tables)} relevant tables"
|
|
2414
|
+
)
|
|
2415
|
+
return db_response.to_markdown()
|
|
2416
|
+
|
|
2417
|
+
return "\n\n".join(sections)
|
|
2418
|
+
|
|
2419
|
+
async def _format_response(
|
|
2420
|
+
self,
|
|
2421
|
+
query: str,
|
|
2422
|
+
db_response: DatabaseResponse,
|
|
2423
|
+
is_structured_output: bool,
|
|
2424
|
+
structured_output_class: Optional[Type[BaseModel]],
|
|
2425
|
+
llm_response: Optional[AIMessage],
|
|
2426
|
+
route: RouteDecision,
|
|
2427
|
+
output_format: Optional[str],
|
|
2428
|
+
discovered_tables: List[TableMetadata],
|
|
2429
|
+
**kwargs
|
|
2430
|
+
) -> AIMessage:
|
|
2431
|
+
"""Format final response based on route decision."""
|
|
2432
|
+
|
|
2433
|
+
if db_response.is_documentation and discovered_tables and not db_response.documentation:
|
|
2434
|
+
# Generate documentation on the fly
|
|
2435
|
+
db_response.documentation = await self._format_table_documentation(
|
|
2436
|
+
discovered_tables, route.user_role, query
|
|
2437
|
+
)
|
|
2438
|
+
|
|
2439
|
+
# Check if we have data to transform
|
|
2440
|
+
has_data = (
    OutputComponent.DATA_RESULTS in route.components or
    OutputComponent.DATAFRAME_OUTPUT in route.components
) and db_response.data is not None and len(db_response.data) > 0
|
|
2444
|
+
|
|
2445
|
+
if has_data and is_structured_output:
|
|
2446
|
+
# Handle DataFrame input
|
|
2447
|
+
output_data = self._to_structured_format(
|
|
2448
|
+
db_response.data,
|
|
2449
|
+
structured_output_class
|
|
2450
|
+
)
|
|
2451
|
+
response_text = ""
|
|
2452
|
+
# Generate response text based on format preference
|
|
2453
|
+
elif output_format == "markdown":
|
|
2454
|
+
response_text = db_response.to_markdown()
|
|
2455
|
+
if OutputComponent.DATAFRAME_OUTPUT in route.components and isinstance(db_response.data, pd.DataFrame):
|
|
2456
|
+
output_data = db_response.data
|
|
2457
|
+
elif OutputComponent.DATA_RESULTS in route.components:
|
|
2458
|
+
output_data = db_response.data
|
|
2459
|
+
elif output_format == "json":
|
|
2460
|
+
response_text = db_response.to_json()
|
|
2461
|
+
if OutputComponent.DATAFRAME_OUTPUT in route.components and isinstance(db_response.data, pd.DataFrame):
|
|
2462
|
+
output_data = db_response.data
|
|
2463
|
+
elif OutputComponent.DATA_RESULTS in route.components:
|
|
2464
|
+
output_data = db_response.data
|
|
2465
|
+
else:
|
|
2466
|
+
response_text = self._format_as_text(
|
|
2467
|
+
db_response,
|
|
2468
|
+
route.user_role,
|
|
2469
|
+
discovered_tables
|
|
2470
|
+
)
|
|
2471
|
+
|
|
2472
|
+
# Prepare output data (don't clobber structured output produced above)
if not (has_data and is_structured_output):
    output_data = None
    if OutputComponent.DATAFRAME_OUTPUT in route.components and isinstance(db_response.data, pd.DataFrame):
        output_data = db_response.data
    elif OutputComponent.DATA_RESULTS in route.components:
        output_data = db_response.data
|
|
2478
|
+
|
|
2479
|
+
# Extract usage information from LLM response
|
|
2480
|
+
usage_info = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
|
|
2481
|
+
if llm_response and hasattr(llm_response, 'usage') and llm_response.usage:
|
|
2482
|
+
usage_info = llm_response.usage
|
|
2483
|
+
|
|
2484
|
+
# Extract model and provider info from LLM response if available
|
|
2485
|
+
model_name = getattr(self, '_llm_model', 'unknown')
|
|
2486
|
+
provider_name = str(getattr(self, '_llm', 'unknown'))
|
|
2487
|
+
|
|
2488
|
+
if llm_response:
|
|
2489
|
+
if hasattr(llm_response, 'model') and llm_response.model:
|
|
2490
|
+
model_name = llm_response.model
|
|
2491
|
+
if hasattr(llm_response, 'provider') and llm_response.provider:
|
|
2492
|
+
provider_name = str(llm_response.provider)
|
|
2493
|
+
|
|
2494
|
+
return AIMessage(
|
|
2495
|
+
input=query,
|
|
2496
|
+
response=response_text,
|
|
2497
|
+
output=output_data,
|
|
2498
|
+
model=model_name,
|
|
2499
|
+
provider=provider_name,
|
|
2500
|
+
metadata={
|
|
2501
|
+
"user_role": route.user_role.value,
|
|
2502
|
+
"components_included": [comp.name for comp in OutputComponent if comp in route.components],
|
|
2503
|
+
"intent": route.intent.value,
|
|
2504
|
+
"primary_schema": route.primary_schema,
|
|
2505
|
+
"sql_query": db_response.query,
|
|
2506
|
+
"row_count": db_response.row_count,
|
|
2507
|
+
"execution_time_ms": db_response.execution_time_ms,
|
|
2508
|
+
"has_dataframe": isinstance(db_response.data, pd.DataFrame),
|
|
2509
|
+
"data_format": "dataframe" if isinstance(db_response.data, pd.DataFrame) else "dict_list",
|
|
2510
|
+
"discovered_tables": [t.full_name for t in discovered_tables],
|
|
2511
|
+
"is_documentation": db_response.is_documentation,
|
|
2512
|
+
"llm_used": getattr(self, '_llm_model', 'unknown'),
|
|
2513
|
+
},
|
|
2514
|
+
usage=usage_info
|
|
2515
|
+
)
|
|
2516
|
+
|
|
2517
|
+
    def _to_structured_format(self, data, output_format: Type) -> Union[List, object]:
        """Convert data to structured format using Pydantic model."""
        if not output_format:
            return data

        try:
            if isinstance(data, pd.DataFrame):
                data = data.to_dict('records')

            if isinstance(data, list):
                return [
                    output_format(**item) if isinstance(item, dict) else item for item in data
                ]
            elif isinstance(data, dict):
                return output_format(**data)
            else:
                self.logger.warning(
                    "Data is neither list nor dict; returning as-is."
                )
                return data
        except Exception as e:
            self.logger.error(f"Unexpected error during structuring: {e}")
            return data
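A minimal, self-contained sketch of the same conversion with a hypothetical Pydantic model (model name and columns are illustrative, not part of the package):

```python
import pandas as pd
from pydantic import BaseModel

class ProductRow(BaseModel):
    # Hypothetical row model; the real structured_output_class is caller-supplied.
    product_id: int
    pricing: float

df = pd.DataFrame([
    {"product_id": 1, "pricing": 19.99},
    {"product_id": 2, "pricing": 5.50},
])

# Same steps as _to_structured_format: DataFrame -> list of dicts -> model instances.
records = df.to_dict('records')
structured = [ProductRow(**row) for row in records]
print(structured[0].pricing)  # 19.99
```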
|
|
2540
|
+
|
|
2541
|
+
def _extract_performance_metrics(
|
|
2542
|
+
self,
|
|
2543
|
+
query_plan: str,
|
|
2544
|
+
execution_time: float,
|
|
2545
|
+
plan_json: Optional[List[Dict]] = None
|
|
2546
|
+
) -> Dict[str, Any]:
|
|
2547
|
+
"""Extract performance metrics from query execution plan."""
|
|
2548
|
+
|
|
2549
|
+
metrics = {
|
|
2550
|
+
"execution_time_ms": execution_time,
|
|
2551
|
+
"estimated_cost": "N/A",
|
|
2552
|
+
"rows_examined": "N/A",
|
|
2553
|
+
"rows_planned": "N/A",
|
|
2554
|
+
"index_usage": "Unknown",
|
|
2555
|
+
"scan_types": [],
|
|
2556
|
+
"join_types": [],
|
|
2557
|
+
"buffer_metrics": {},
|
|
2558
|
+
"planning_time_ms": "N/A"
|
|
2559
|
+
}
|
|
2560
|
+
|
|
2561
|
+
# If we have JSON plan, extract from there (more accurate)
|
|
2562
|
+
if plan_json and isinstance(plan_json, list) and len(plan_json) > 0:
|
|
2563
|
+
try:
|
|
2564
|
+
plan_data = plan_json[0]
|
|
2565
|
+
|
|
2566
|
+
# Planning time
|
|
2567
|
+
if "Planning Time" in plan_data:
|
|
2568
|
+
metrics["planning_time_ms"] = plan_data["Planning Time"]
|
|
2569
|
+
|
|
2570
|
+
# Extract from main plan
|
|
2571
|
+
main_plan = plan_data.get("Plan", {})
|
|
2572
|
+
|
|
2573
|
+
# Cost information
|
|
2574
|
+
if "Total Cost" in main_plan:
|
|
2575
|
+
metrics["estimated_cost"] = main_plan["Total Cost"]
|
|
2576
|
+
|
|
2577
|
+
# Row information
|
|
2578
|
+
if "Actual Rows" in main_plan:
|
|
2579
|
+
metrics["rows_examined"] = main_plan["Actual Rows"]
|
|
2580
|
+
if "Plan Rows" in main_plan:
|
|
2581
|
+
metrics["rows_planned"] = main_plan["Plan Rows"]
|
|
2582
|
+
|
|
2583
|
+
# Buffer statistics
|
|
2584
|
+
buffer_stats = {}
|
|
2585
|
+
for key in ["Shared Hit Blocks", "Shared Read Blocks", "Temp Read Blocks", "Temp Written Blocks"]:
|
|
2586
|
+
if key in main_plan:
|
|
2587
|
+
buffer_stats[key.lower().replace(" ", "_")] = main_plan[key]
|
|
2588
|
+
if buffer_stats:
|
|
2589
|
+
metrics["buffer_metrics"] = buffer_stats
|
|
2590
|
+
|
|
2591
|
+
# Recursively analyze all nodes for scan/join types
|
|
2592
|
+
def analyze_node(node):
|
|
2593
|
+
node_type = node.get("Node Type", "")
|
|
2594
|
+
|
|
2595
|
+
# Scan types
|
|
2596
|
+
if "scan" in node_type.lower():
|
|
2597
|
+
scan_type = node_type
|
|
2598
|
+
metrics["scan_types"].append(scan_type)
|
|
2599
|
+
|
|
2600
|
+
# Index usage detection
|
|
2601
|
+
if "index" in node_type.lower():
|
|
2602
|
+
if "index only" in node_type.lower():
|
|
2603
|
+
metrics["index_usage"] = "Index-only scan"
|
|
2604
|
+
elif "bitmap" in node_type.lower():
|
|
2605
|
+
metrics["index_usage"] = "Bitmap index scan"
|
|
2606
|
+
else:
|
|
2607
|
+
metrics["index_usage"] = "Index scan"
|
|
2608
|
+
elif "seq" in node_type.lower():
|
|
2609
|
+
metrics["index_usage"] = "Sequential scan (no indexes)"
|
|
2610
|
+
|
|
2611
|
+
# Join types
|
|
2612
|
+
if "join" in node_type.lower():
|
|
2613
|
+
metrics["join_types"].append(node_type)
|
|
2614
|
+
|
|
2615
|
+
# Process child plans
|
|
2616
|
+
if "Plans" in node:
|
|
2617
|
+
for child_plan in node["Plans"]:
|
|
2618
|
+
analyze_node(child_plan)
|
|
2619
|
+
|
|
2620
|
+
analyze_node(main_plan)
|
|
2621
|
+
|
|
2622
|
+
# Remove duplicates
|
|
2623
|
+
metrics["scan_types"] = list(set(metrics["scan_types"]))
|
|
2624
|
+
metrics["join_types"] = list(set(metrics["join_types"]))
|
|
2625
|
+
|
|
2626
|
+
return metrics
|
|
2627
|
+
|
|
2628
|
+
except Exception as e:
|
|
2629
|
+
self.logger.error(f"Error extracting metrics from JSON plan: {e}")
|
|
2630
|
+
# Fall back to text parsing
|
|
2631
|
+
|
|
2632
|
+
# Fallback: Extract from text plan
|
|
2633
|
+
if not query_plan:
|
|
2634
|
+
return metrics
|
|
2635
|
+
|
|
2636
|
+
lines = query_plan.split('\n')
|
|
2637
|
+
for line in lines:
|
|
2638
|
+
line_lower = line.lower()
|
|
2639
|
+
|
|
2640
|
+
# Extract cost information (text plans print "cost=start..total")
if 'cost=' in line_lower:
    cost_match = re.search(r'cost=[\d.]+\.\.([\d.]+)', line_lower)
    if cost_match:
        metrics["estimated_cost"] = float(cost_match.group(1))

# Extract row information (text plans print "rows=NNN")
if 'rows=' in line_lower:
    rows_match = re.search(r'rows=(\d+)', line_lower)
    if rows_match:
        metrics["rows_examined"] = int(rows_match.group(1))
|
|
2651
|
+
|
|
2652
|
+
# Detect scan types
|
|
2653
|
+
if 'seq scan' in line_lower:
|
|
2654
|
+
metrics["scan_types"].append("Sequential Scan")
|
|
2655
|
+
metrics["index_usage"] = "No indexes used"
|
|
2656
|
+
elif 'index scan' in line_lower:
|
|
2657
|
+
metrics["scan_types"].append("Index Scan")
|
|
2658
|
+
metrics["index_usage"] = "Indexes used"
|
|
2659
|
+
elif 'index only scan' in line_lower:
|
|
2660
|
+
metrics["scan_types"].append("Index Only Scan")
|
|
2661
|
+
metrics["index_usage"] = "Index-only access"
|
|
2662
|
+
elif 'bitmap heap scan' in line_lower:
|
|
2663
|
+
metrics["scan_types"].append("Bitmap Heap Scan")
|
|
2664
|
+
metrics["index_usage"] = "Bitmap index used"
|
|
2665
|
+
|
|
2666
|
+
# Detect join types
|
|
2667
|
+
if 'nested loop' in line_lower:
|
|
2668
|
+
metrics["join_types"].append("Nested Loop")
|
|
2669
|
+
elif 'hash join' in line_lower:
|
|
2670
|
+
metrics["join_types"].append("Hash Join")
|
|
2671
|
+
elif 'merge join' in line_lower:
|
|
2672
|
+
metrics["join_types"].append("Merge Join")
|
|
2673
|
+
|
|
2674
|
+
# Remove duplicates
|
|
2675
|
+
metrics["scan_types"] = list(set(metrics["scan_types"]))
|
|
2676
|
+
metrics["join_types"] = list(set(metrics["join_types"]))
|
|
2677
|
+
|
|
2678
|
+
return metrics
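The JSON branch above expects the shape produced by `EXPLAIN (ANALYZE, FORMAT JSON)`: a one-element list whose item carries `Planning Time` and a nested `Plan` tree. A hand-written example of that shape (values are illustrative):

```python
# Illustrative plan_json in the shape EXPLAIN (ANALYZE, FORMAT JSON) returns.
plan_json = [{
    "Planning Time": 0.42,
    "Plan": {
        "Node Type": "Hash Join",
        "Total Cost": 241.37,
        "Plan Rows": 1200,
        "Actual Rows": 5,
        "Shared Hit Blocks": 64,
        "Plans": [
            {"Node Type": "Seq Scan", "Actual Rows": 42000},
            {"Node Type": "Index Scan", "Actual Rows": 5},
        ],
    },
}]
# _extract_performance_metrics("", 12.5, plan_json) would report the cost,
# the planned-vs-actual row discrepancy, the buffer hits, both scan types,
# and the hash join collected while walking the "Plans" children.
```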
|
|
2679
|
+
|
|
2680
|
+
async def _generate_optimization_tips(
|
|
2681
|
+
self,
|
|
2682
|
+
sql_query: str,
|
|
2683
|
+
query_plan: str,
|
|
2684
|
+
metadata_context: str,
|
|
2685
|
+
context: Optional[str] = None,
|
|
2686
|
+
plan_json: Optional[List[Dict]] = None # Add JSON plan data
|
|
2687
|
+
) -> Tuple[List[str], Optional[AIMessage]]:
|
|
2688
|
+
"""
|
|
2689
|
+
LLM-based optimization tips with better parsing.
|
|
2690
|
+
"""
|
|
2691
|
+
if not query_plan:
|
|
2692
|
+
return ["Enable query plan analysis for optimization suggestions"], None
|
|
2693
|
+
|
|
2694
|
+
self.logger.debug("🔧 Generating LLM-based optimization tips...")
|
|
2695
|
+
|
|
2696
|
+
# Enhanced prompt with better formatting instructions
|
|
2697
|
+
optimization_prompt = f"""
|
|
2698
|
+
You are a PostgreSQL performance tutor helping developers understand and fix query performance issues.
|
|
2699
|
+
|
|
2700
|
+
**SQL Query:**
|
|
2701
|
+
```sql
|
|
2702
|
+
{sql_query}
|
|
2703
|
+
```
|
|
2704
|
+
|
|
2705
|
+
**Execution Plan:**
|
|
2706
|
+
```
|
|
2707
|
+
{query_plan}
|
|
2708
|
+
```
|
|
2709
|
+
* If available, here is the JSON representation of the execution plan for more accurate analysis: *
|
|
2710
|
+
```json
|
|
2711
|
+
{plan_json}
|
|
2712
|
+
```
|
|
2713
|
+
|
|
2714
|
+
**Available Schema Context:**
|
|
2715
|
+
{metadata_context[:1000] if metadata_context else 'No schema context available'}
|
|
2716
|
+
|
|
2717
|
+
{context}
|
|
2718
|
+
|
|
2719
|
+
**EDUCATIONAL MISSION:**
|
|
2720
|
+
Your goal is to teach PostgreSQL optimization concepts while providing actionable solutions. Each recommendation should:
|
|
2721
|
+
1. EXPLAIN the underlying PostgreSQL concept (why this matters)
|
|
2722
|
+
2. IDENTIFY the specific issue in this query
|
|
2723
|
+
3. PROVIDE the exact SQL commands to fix it
|
|
2724
|
+
4. EXPLAIN what the fix accomplishes
|
|
2725
|
+
|
|
2726
|
+
**RESPONSE FORMAT:**
|
|
2727
|
+
- Start each tip with an emoji and descriptive title
|
|
2728
|
+
- Include a brief explanation of the PostgreSQL concept
|
|
2729
|
+
- Provide specific SQL commands with actual table/column names from the query
|
|
2730
|
+
- Explain the expected performance impact
|
|
2731
|
+
|
|
2732
|
+
**EXAMPLE GOOD TIP:**
|
|
2733
|
+
📊 **Update Table Statistics for Better Query Planning**
|
|
2734
|
+
|
|
2735
|
+
**What's happening:** PostgreSQL's query planner uses table statistics to estimate how many rows operations will return. When these statistics are outdated, the planner makes poor decisions (like choosing slow sequential scans over fast index scans).
|
|
2736
|
+
|
|
2737
|
+
**The issue:** Your execution plan shows estimated 42M rows but actual 5 rows - this massive discrepancy indicates stale statistics on the `form_data` table.
|
|
2738
|
+
|
|
2739
|
+
**Fix this with:**
|
|
2740
|
+
```sql
|
|
2741
|
+
-- Update statistics for the specific table
|
|
2742
|
+
ANALYZE hisense.form_data;
|
|
2743
|
+
|
|
2744
|
+
-- Or update all tables in the schema
|
|
2745
|
+
ANALYZE;
|
|
2746
|
+
|
|
2747
|
+
-- Check when statistics were last updated
|
|
2748
|
+
SELECT schemaname, tablename, last_analyze, last_autoanalyze
|
|
2749
|
+
FROM pg_stat_user_tables
|
|
2750
|
+
WHERE tablename = 'form_data';
|
|
2751
|
+
```
|
|
2752
|
+
|
|
2753
|
+
**Why this helps:** Fresh statistics allow PostgreSQL to choose optimal execution paths, potentially changing sequential scans to index scans and improving query performance by orders of magnitude.
|
|
2754
|
+
|
|
2755
|
+
**FOCUS AREAS FOR THIS QUERY:**
|
|
2756
|
+
Based on the execution plan, prioritize recommendations about:
|
|
2757
|
+
- Statistics accuracy (row estimate discrepancies)
|
|
2758
|
+
- Index usage and creation with specific column combinations
|
|
2759
|
+
- Query structure improvements with rewritten SQL examples
|
|
2760
|
+
- Buffer usage and I/O optimization
|
|
2761
|
+
- Join strategy improvements (if applicable)
|
|
2762
|
+
|
|
2763
|
+
**IMPORTANT REQUIREMENTS:**
|
|
2764
|
+
- Always include the actual SQL commands to implement your suggestions
|
|
2765
|
+
- Use the real table and column names from the provided query
|
|
2766
|
+
- Explain PostgreSQL concepts in accessible terms
|
|
2767
|
+
- Focus on the most impactful optimizations first (biggest performance gains)
|
|
2768
|
+
- Limit to 3-4 high-impact recommendations
|
|
2769
|
+
|
|
2770
|
+
Provide specific, educational recommendations with concrete implementation steps:
|
|
2771
|
+
"""
|
|
2772
|
+
try:
|
|
2773
|
+
# Call LLM for optimization analysis
|
|
2774
|
+
async with self._llm as client:
|
|
2775
|
+
llm_response = await client.ask(
|
|
2776
|
+
prompt=optimization_prompt,
|
|
2777
|
+
temperature=0.1,
|
|
2778
|
+
max_tokens=4096,
|
|
2779
|
+
max_retries=2,
|
|
2780
|
+
use_tools=False,
|
|
2781
|
+
stateless=True
|
|
2782
|
+
)
|
|
2783
|
+
|
|
2784
|
+
response_text = str(llm_response.output) if llm_response.output else str(llm_response.response)
|
|
2785
|
+
self.logger.debug(f"🔧 LLM Optimization Response: {response_text[:200]}...")
|
|
2786
|
+
|
|
2787
|
+
# Enhanced parsing logic
tips = self._parse_tips(response_text)
if tips:
    self.logger.info(f"✅ Generated {len(tips)} optimization tips")
    return tips, llm_response
# No parsable tips: fall back to the pattern-based suggestions
return self._generate_basic_optimization_tips(sql_query, query_plan), llm_response
|
|
2793
|
+
except Exception as e:
|
|
2794
|
+
self.logger.error(f"LLM Optimization Tips Error: {e}")
|
|
2795
|
+
|
|
2796
|
+
# Fallback to basic analysis if LLM fails
|
|
2797
|
+
return self._generate_basic_optimization_tips(
|
|
2798
|
+
sql_query,
|
|
2799
|
+
query_plan
|
|
2800
|
+
), None
|
|
2801
|
+
|
|
2802
|
+
    def _parse_tips(self, response_text: str) -> List[str]:
        """Parse performance tips with multi-line content."""
        tips = []
        current_tip = []
        in_tip = False

        lines = response_text.split('\n')

        for line in lines:
            line = line.strip()

            # Start of a new tip (emoji + title)
            if (line and any(emoji in line[:10] for emoji in ['📊', '⚡', '🔗', '💾', '🔧', '📈', '🎯', '🔍'])
                    and ('**' in line or line.startswith(('📊', '⚡', '🔗', '💾', '🔧', '📈', '🎯', '🔍')))):

                # Save previous tip if exists
                if current_tip:
                    tip_text = '\n'.join(current_tip).strip()
                    if len(tip_text) > 50:  # Only keep substantial tips
                        tips.append(tip_text)

                # Start new tip
                current_tip = [line]
                in_tip = True

            elif in_tip and line:
                # Continue building current tip - KEEP ALL CONTENT
                current_tip.append(line)

            elif in_tip and not line:
                # Empty line - add it to preserve formatting
                current_tip.append('')

        # Add the last tip
        if current_tip:
            tip_text = '\n'.join(current_tip).strip()
            if len(tip_text) > 50:
                tips.append(tip_text)

        # Return all tips without truncation - developers need complete information
        return tips
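To show what the parser keeps, a hypothetical LLM reply and the outcome the method would produce (two tips, each retained in full because both exceed the 50-character threshold):

```python
# Hypothetical optimization response; emoji-prefixed titles start new tips.
sample = "\n".join([
    "📊 **Update Table Statistics for Better Query Planning**",
    "Stale statistics mislead the planner; run ANALYZE on the table.",
    "",
    "⚡ **Add an Index on the Filter Column**",
    "A btree index on the pricing column avoids the sequential scan in the plan.",
])
# agent._parse_tips(sample) -> two multi-line strings, one per emoji-titled tip.
```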
|
|
2843
|
+
|
|
2844
|
+
    def _generate_basic_optimization_tips(self, sql_query: str, query_plan: str) -> List[str]:
        """Fallback basic optimization tips using pattern matching."""
        tips = []
        plan_lower = query_plan.lower()
        query_lower = sql_query.lower() if sql_query else ""

        # Sequential scan detection
        if 'seq scan' in plan_lower:
            tips.append("⚡ Consider adding indexes on frequently filtered columns to avoid sequential scans")

        # Large sort operations
        if 'sort' in plan_lower:
            tips.append("📈 Large sort operation detected - consider adding indexes for ORDER BY columns")

        # Nested loop joins
        if 'nested loop' in plan_lower and 'join' in query_lower:
            tips.append("🔗 Nested loop joins detected - ensure join columns are indexed")

        # Query structure tips
        if query_lower:
            if 'select *' in query_lower:
                tips.append("📝 Avoid SELECT * - specify only needed columns for better performance")

        return tips or ["✅ Query appears to be well-optimized"]
|
|
2868
|
+
|
|
2869
|
+
    def _extract_table_names_from_metadata(self, metadata_context: str) -> List[str]:
        """Extract table names from metadata context."""
        if not metadata_context:
            return []

        # Look for table references in YAML context
        table_matches = re.findall(r'table:\s+\w+\.(\w+)', metadata_context)
        return list(set(table_matches))[:5]  # Limit to 5 unique tables
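The regex assumes `table: schema.name` entries in the YAML-style metadata context. A hypothetical snippet and the names it yields:

```python
import re

# Hypothetical metadata context as assembled elsewhere in the agent.
metadata_context = """
table: hisense.form_data
  columns: [form_id, pricing, visit_date]
table: hisense.stores
  columns: [store_id, store_name]
"""
print(re.findall(r'table:\s+\w+\.(\w+)', metadata_context))  # ['form_data', 'stores']
```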
|
|
2877
|
+
|
|
2878
|
+
async def _generate_examples(
|
|
2879
|
+
self,
|
|
2880
|
+
query: str,
|
|
2881
|
+
metadata_context: str,
|
|
2882
|
+
discovered_tables: List[TableMetadata],
|
|
2883
|
+
schema_name: str
|
|
2884
|
+
) -> List[str]:
|
|
2885
|
+
"""Generate usage examples based on available schema metadata."""
|
|
2886
|
+
|
|
2887
|
+
examples = []
|
|
2888
|
+
|
|
2889
|
+
if discovered_tables:
|
|
2890
|
+
# Generate examples for each discovered table (limit to 2 for brevity)
|
|
2891
|
+
for i, table in enumerate(discovered_tables[:2]):
|
|
2892
|
+
table_examples = [
|
|
2893
|
+
f"-- Examples for table: {table.full_name}",
|
|
2894
|
+
f"SELECT * FROM {table.full_name} LIMIT 10;",
|
|
2895
|
+
"",
|
|
2896
|
+
f"SELECT COUNT(*) FROM {table.full_name};",
|
|
2897
|
+
""
|
|
2898
|
+
]
|
|
2899
|
+
# Add column-specific examples if columns are available
|
|
2900
|
+
if table.columns:
|
|
2901
|
+
# Find interesting columns (non-id, non-timestamp)
|
|
2902
|
+
interesting_cols = [
|
|
2903
|
+
col['name'] for col in table.columns
|
|
2904
|
+
if not col['name'].lower().endswith(('_id', 'id'))
|
|
2905
|
+
and col['type'].lower() not in ('timestamp', 'timestamptz')
|
|
2906
|
+
][:5] # Limit to 5 columns
|
|
2907
|
+
if interesting_cols:
|
|
2908
|
+
col_list = ', '.join(interesting_cols)
|
|
2909
|
+
table_examples.extend([
|
|
2910
|
+
f"SELECT {col_list} FROM {table.full_name} WHERE {interesting_cols[0]} IS NOT NULL LIMIT 2;",
|
|
2911
|
+
""
|
|
2912
|
+
])
|
|
2913
|
+
examples.extend(table_examples)
|
|
2914
|
+
# Add schema exploration examples
|
|
2915
|
+
examples.extend([
|
|
2916
|
+
"-- Schema exploration",
|
|
2917
|
+
f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema_name}';",
|
|
2918
|
+
"",
|
|
2919
|
+
"-- Find tables with specific column patterns",
|
|
2920
|
+
f"SELECT table_name, column_name FROM information_schema.columns "
|
|
2921
|
+
f"WHERE table_schema = '{schema_name}' AND column_name LIKE '%name%';"
|
|
2922
|
+
])
|
|
2923
|
+
return ["\n".join(examples)]
|
|
2924
|
+
|
|
2925
|
+
# Extract table names from metadata context
|
|
2926
|
+
tables = self._extract_table_names_from_metadata(metadata_context)
|
|
2927
|
+
|
|
2928
|
+
if not tables:
|
|
2929
|
+
# Fallback examples
|
|
2930
|
+
return [
|
|
2931
|
+
f"SELECT * FROM {schema_name}.table_name LIMIT 10;",
|
|
2932
|
+
f"SELECT COUNT(*) FROM {schema_name}.table_name;",
|
|
2933
|
+
f"DESCRIBE {schema_name}.table_name;"
|
|
2934
|
+
]
|
|
2935
|
+
|
|
2936
|
+
# Generate examples for available tables
|
|
2937
|
+
for table in tables[:2]: # Limit to 2 tables to avoid clutter
|
|
2938
|
+
table_examples = [
    f"-- Basic data retrieval from {table}",
    f"SELECT * FROM {schema_name}.{table} LIMIT 10;",
    "",
    f"-- Count records in {table}",
    f"SELECT COUNT(*) FROM {schema_name}.{table};",
    "",
    "-- Get table structure",
    f"\\d {schema_name}.{table};"
]
|
|
2948
|
+
examples.extend(table_examples)
|
|
2949
|
+
|
|
2950
|
+
# Add schema exploration examples
|
|
2951
|
+
examples.extend([
|
|
2952
|
+
"",
|
|
2953
|
+
"-- List all tables in schema",
|
|
2954
|
+
f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema_name}';",
|
|
2955
|
+
"",
|
|
2956
|
+
"-- Find tables containing specific column",
|
|
2957
|
+
f"SELECT table_name FROM information_schema.columns WHERE table_schema = '{schema_name}' AND column_name LIKE '%name%';"
|
|
2958
|
+
])
|
|
2959
|
+
|
|
2960
|
+
return ["\n".join(examples)]
|
|
2961
|
+
|
|
2962
|
+
    def _create_error_response(
        self,
        query: str,
        error: Exception,
        user_role
    ) -> 'AIMessage':
        """Create enhanced error response with role-appropriate information."""
        error_msg = f"Error processing database query: {str(error)}"

        # Role-specific error information
        if user_role.value == 'developer':
            error_msg += "\n\n**Debug Information:**"
            error_msg += f"\n- Error Type: {type(error).__name__}"
            error_msg += f"\n- Primary Schema: {self.primary_schema}"
            error_msg += f"\n- Allowed Schemas: {', '.join(self.allowed_schemas)}"
            error_msg += f"\n- Tools Available: {len(self.tool_manager.get_tools())}"

        elif user_role.value == 'database_admin':
            error_msg += "\n\n**Technical Details:**"
            error_msg += f"\n- Error: {type(error).__name__}: {str(error)}"
            error_msg += f"\n- Schema Context: {self.primary_schema}"

        else:
            # Simplified error for business users and analysts
            error_msg = "Unable to process your request. Please try rephrasing your query or contact support."

        return AIMessage(
            input=query,
            response=error_msg,
            output=None,
            model="error_handler",
            provider="system",
            metadata={
                "error_type": type(error).__name__,
                "error_message": str(error),
                "user_role": user_role.value,
                "primary_schema": self.primary_schema
            },
            usage=CompletionUsage(
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0
            )
        )
|
|
3006
|
+
|
|
3007
|
+
async def _update_conversation_memory(
|
|
3008
|
+
self,
|
|
3009
|
+
user_id: str,
|
|
3010
|
+
session_id: str,
|
|
3011
|
+
user_prompt: str,
|
|
3012
|
+
response: AIMessage,
|
|
3013
|
+
user_context: Optional[str],
|
|
3014
|
+
vector_metadata: Dict[str, Any],
|
|
3015
|
+
conversation_history
|
|
3016
|
+
):
|
|
3017
|
+
"""Update conversation memory with the current interaction."""
|
|
3018
|
+
if not self.conversation_memory or not conversation_history:
|
|
3019
|
+
return
|
|
3020
|
+
|
|
3021
|
+
try:
|
|
3022
|
+
assistant_content = str(response.output) if response.output is not None else (response.response or "")
|
|
3023
|
+
|
|
3024
|
+
# Extract tools used
|
|
3025
|
+
tools_used = []
|
|
3026
|
+
if hasattr(response, 'tool_calls') and response.tool_calls:
|
|
3027
|
+
tools_used = [tool_call.name for tool_call in response.tool_calls]
|
|
3028
|
+
|
|
3029
|
+
turn = ConversationTurn(
|
|
3030
|
+
turn_id=str(uuid.uuid4()),
|
|
3031
|
+
user_id=user_id,
|
|
3032
|
+
user_message=user_prompt,
|
|
3033
|
+
assistant_response=assistant_content,
|
|
3034
|
+
metadata={
|
|
3035
|
+
'user_context': user_context,
|
|
3036
|
+
'tools_used': tools_used,
|
|
3037
|
+
'primary_schema': self.primary_schema,
|
|
3038
|
+
'tables_referenced': vector_metadata.get('tables_referenced', []),
|
|
3039
|
+
'sources_used': vector_metadata.get('sources', []),
|
|
3040
|
+
'has_sql_execution': bool(response.metadata and response.metadata.get('sql_executed')),
|
|
3041
|
+
'execution_success': response.metadata.get('execution_success') if response.metadata else None
|
|
3042
|
+
}
|
|
3043
|
+
)
|
|
3044
|
+
|
|
3045
|
+
chatbot_key = getattr(self, 'chatbot_id', None)
|
|
3046
|
+
if chatbot_key is not None:
|
|
3047
|
+
chatbot_key = str(chatbot_key)
|
|
3048
|
+
await self.conversation_memory.add_turn(
|
|
3049
|
+
user_id,
|
|
3050
|
+
session_id,
|
|
3051
|
+
turn,
|
|
3052
|
+
chatbot_id=chatbot_key
|
|
3053
|
+
)
|
|
3054
|
+
self.logger.debug(
|
|
3055
|
+
f"Updated conversation memory for session {session_id}"
|
|
3056
|
+
)
|
|
3057
|
+
|
|
3058
|
+
except Exception as e:
|
|
3059
|
+
self.logger.error(
|
|
3060
|
+
f"Failed to update conversation memory: {e}"
|
|
3061
|
+
)
|
|
3062
|
+
|
|
3063
|
+
    async def cleanup(self) -> None:
        """Cleanup database and parent resources."""
        try:
            # Close database engine
            if self.engine:
                await self.engine.dispose()
                self.logger.debug("Database engine disposed")
        except Exception as e:
            self.logger.error(f"Error during DB agent cleanup: {e}")