ai-parrot 0.17.2__cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (535) hide show
  1. agentui/.prettierrc +15 -0
  2. agentui/QUICKSTART.md +272 -0
  3. agentui/README.md +59 -0
  4. agentui/env.example +16 -0
  5. agentui/jsconfig.json +14 -0
  6. agentui/package-lock.json +4242 -0
  7. agentui/package.json +34 -0
  8. agentui/scripts/postinstall/apply-patches.mjs +260 -0
  9. agentui/src/app.css +61 -0
  10. agentui/src/app.d.ts +13 -0
  11. agentui/src/app.html +12 -0
  12. agentui/src/components/LoadingSpinner.svelte +64 -0
  13. agentui/src/components/ThemeSwitcher.svelte +159 -0
  14. agentui/src/components/index.js +4 -0
  15. agentui/src/lib/api/bots.ts +60 -0
  16. agentui/src/lib/api/chat.ts +22 -0
  17. agentui/src/lib/api/http.ts +25 -0
  18. agentui/src/lib/components/BotCard.svelte +33 -0
  19. agentui/src/lib/components/ChatBubble.svelte +63 -0
  20. agentui/src/lib/components/Toast.svelte +21 -0
  21. agentui/src/lib/config.ts +20 -0
  22. agentui/src/lib/stores/auth.svelte.ts +73 -0
  23. agentui/src/lib/stores/theme.svelte.js +64 -0
  24. agentui/src/lib/stores/toast.svelte.ts +31 -0
  25. agentui/src/lib/utils/conversation.ts +39 -0
  26. agentui/src/routes/+layout.svelte +20 -0
  27. agentui/src/routes/+page.svelte +232 -0
  28. agentui/src/routes/login/+page.svelte +200 -0
  29. agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
  30. agentui/src/routes/talk/[agentId]/+page.ts +7 -0
  31. agentui/static/README.md +1 -0
  32. agentui/svelte.config.js +11 -0
  33. agentui/tailwind.config.ts +53 -0
  34. agentui/tsconfig.json +3 -0
  35. agentui/vite.config.ts +10 -0
  36. ai_parrot-0.17.2.dist-info/METADATA +472 -0
  37. ai_parrot-0.17.2.dist-info/RECORD +535 -0
  38. ai_parrot-0.17.2.dist-info/WHEEL +6 -0
  39. ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
  40. ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
  41. ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
  42. crew-builder/.prettierrc +15 -0
  43. crew-builder/QUICKSTART.md +259 -0
  44. crew-builder/README.md +113 -0
  45. crew-builder/env.example +17 -0
  46. crew-builder/jsconfig.json +14 -0
  47. crew-builder/package-lock.json +4182 -0
  48. crew-builder/package.json +37 -0
  49. crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
  50. crew-builder/src/app.css +62 -0
  51. crew-builder/src/app.d.ts +13 -0
  52. crew-builder/src/app.html +12 -0
  53. crew-builder/src/components/LoadingSpinner.svelte +64 -0
  54. crew-builder/src/components/ThemeSwitcher.svelte +149 -0
  55. crew-builder/src/components/index.js +9 -0
  56. crew-builder/src/lib/api/bots.ts +60 -0
  57. crew-builder/src/lib/api/chat.ts +80 -0
  58. crew-builder/src/lib/api/client.ts +56 -0
  59. crew-builder/src/lib/api/crew/crew.ts +136 -0
  60. crew-builder/src/lib/api/index.ts +5 -0
  61. crew-builder/src/lib/api/o365/auth.ts +65 -0
  62. crew-builder/src/lib/auth/auth.ts +54 -0
  63. crew-builder/src/lib/components/AgentNode.svelte +43 -0
  64. crew-builder/src/lib/components/BotCard.svelte +33 -0
  65. crew-builder/src/lib/components/ChatBubble.svelte +67 -0
  66. crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
  67. crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
  68. crew-builder/src/lib/components/JsonViewer.svelte +24 -0
  69. crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
  70. crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
  71. crew-builder/src/lib/components/Toast.svelte +67 -0
  72. crew-builder/src/lib/components/Toolbar.svelte +157 -0
  73. crew-builder/src/lib/components/index.ts +10 -0
  74. crew-builder/src/lib/config.ts +8 -0
  75. crew-builder/src/lib/stores/auth.svelte.ts +228 -0
  76. crew-builder/src/lib/stores/crewStore.ts +369 -0
  77. crew-builder/src/lib/stores/theme.svelte.js +145 -0
  78. crew-builder/src/lib/stores/toast.svelte.ts +69 -0
  79. crew-builder/src/lib/utils/conversation.ts +39 -0
  80. crew-builder/src/lib/utils/markdown.ts +122 -0
  81. crew-builder/src/lib/utils/talkHistory.ts +47 -0
  82. crew-builder/src/routes/+layout.svelte +20 -0
  83. crew-builder/src/routes/+page.svelte +539 -0
  84. crew-builder/src/routes/agents/+page.svelte +247 -0
  85. crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
  86. crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
  87. crew-builder/src/routes/builder/+page.svelte +204 -0
  88. crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
  89. crew-builder/src/routes/crew/ask/+page.ts +1 -0
  90. crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
  91. crew-builder/src/routes/login/+page.svelte +197 -0
  92. crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
  93. crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
  94. crew-builder/static/README.md +1 -0
  95. crew-builder/svelte.config.js +11 -0
  96. crew-builder/tailwind.config.ts +53 -0
  97. crew-builder/tsconfig.json +3 -0
  98. crew-builder/vite.config.ts +10 -0
  99. mcp_servers/calculator_server.py +309 -0
  100. parrot/__init__.py +27 -0
  101. parrot/__pycache__/__init__.cpython-310.pyc +0 -0
  102. parrot/__pycache__/version.cpython-310.pyc +0 -0
  103. parrot/_version.py +34 -0
  104. parrot/a2a/__init__.py +48 -0
  105. parrot/a2a/client.py +658 -0
  106. parrot/a2a/discovery.py +89 -0
  107. parrot/a2a/mixin.py +257 -0
  108. parrot/a2a/models.py +376 -0
  109. parrot/a2a/server.py +770 -0
  110. parrot/agents/__init__.py +29 -0
  111. parrot/bots/__init__.py +12 -0
  112. parrot/bots/a2a_agent.py +19 -0
  113. parrot/bots/abstract.py +3139 -0
  114. parrot/bots/agent.py +1129 -0
  115. parrot/bots/basic.py +9 -0
  116. parrot/bots/chatbot.py +669 -0
  117. parrot/bots/data.py +1618 -0
  118. parrot/bots/database/__init__.py +5 -0
  119. parrot/bots/database/abstract.py +3071 -0
  120. parrot/bots/database/cache.py +286 -0
  121. parrot/bots/database/models.py +468 -0
  122. parrot/bots/database/prompts.py +154 -0
  123. parrot/bots/database/retries.py +98 -0
  124. parrot/bots/database/router.py +269 -0
  125. parrot/bots/database/sql.py +41 -0
  126. parrot/bots/db/__init__.py +6 -0
  127. parrot/bots/db/abstract.py +556 -0
  128. parrot/bots/db/bigquery.py +602 -0
  129. parrot/bots/db/cache.py +85 -0
  130. parrot/bots/db/documentdb.py +668 -0
  131. parrot/bots/db/elastic.py +1014 -0
  132. parrot/bots/db/influx.py +898 -0
  133. parrot/bots/db/mock.py +96 -0
  134. parrot/bots/db/multi.py +783 -0
  135. parrot/bots/db/prompts.py +185 -0
  136. parrot/bots/db/sql.py +1255 -0
  137. parrot/bots/db/tools.py +212 -0
  138. parrot/bots/document.py +680 -0
  139. parrot/bots/hrbot.py +15 -0
  140. parrot/bots/kb.py +170 -0
  141. parrot/bots/mcp.py +36 -0
  142. parrot/bots/orchestration/README.md +463 -0
  143. parrot/bots/orchestration/__init__.py +1 -0
  144. parrot/bots/orchestration/agent.py +155 -0
  145. parrot/bots/orchestration/crew.py +3330 -0
  146. parrot/bots/orchestration/fsm.py +1179 -0
  147. parrot/bots/orchestration/hr.py +434 -0
  148. parrot/bots/orchestration/storage/__init__.py +4 -0
  149. parrot/bots/orchestration/storage/memory.py +100 -0
  150. parrot/bots/orchestration/storage/mixin.py +119 -0
  151. parrot/bots/orchestration/verify.py +202 -0
  152. parrot/bots/product.py +204 -0
  153. parrot/bots/prompts/__init__.py +96 -0
  154. parrot/bots/prompts/agents.py +155 -0
  155. parrot/bots/prompts/data.py +216 -0
  156. parrot/bots/prompts/output_generation.py +8 -0
  157. parrot/bots/scraper/__init__.py +3 -0
  158. parrot/bots/scraper/models.py +122 -0
  159. parrot/bots/scraper/scraper.py +1173 -0
  160. parrot/bots/scraper/templates.py +115 -0
  161. parrot/bots/stores/__init__.py +5 -0
  162. parrot/bots/stores/local.py +172 -0
  163. parrot/bots/webdev.py +81 -0
  164. parrot/cli.py +17 -0
  165. parrot/clients/__init__.py +16 -0
  166. parrot/clients/base.py +1491 -0
  167. parrot/clients/claude.py +1191 -0
  168. parrot/clients/factory.py +129 -0
  169. parrot/clients/google.py +4567 -0
  170. parrot/clients/gpt.py +1975 -0
  171. parrot/clients/grok.py +432 -0
  172. parrot/clients/groq.py +986 -0
  173. parrot/clients/hf.py +582 -0
  174. parrot/clients/models.py +18 -0
  175. parrot/conf.py +395 -0
  176. parrot/embeddings/__init__.py +9 -0
  177. parrot/embeddings/base.py +157 -0
  178. parrot/embeddings/google.py +98 -0
  179. parrot/embeddings/huggingface.py +74 -0
  180. parrot/embeddings/openai.py +84 -0
  181. parrot/embeddings/processor.py +88 -0
  182. parrot/exceptions.c +13868 -0
  183. parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
  184. parrot/exceptions.pxd +22 -0
  185. parrot/exceptions.pxi +15 -0
  186. parrot/exceptions.pyx +44 -0
  187. parrot/generators/__init__.py +29 -0
  188. parrot/generators/base.py +200 -0
  189. parrot/generators/html.py +293 -0
  190. parrot/generators/react.py +205 -0
  191. parrot/generators/streamlit.py +203 -0
  192. parrot/generators/template.py +105 -0
  193. parrot/handlers/__init__.py +4 -0
  194. parrot/handlers/agent.py +861 -0
  195. parrot/handlers/agents/__init__.py +1 -0
  196. parrot/handlers/agents/abstract.py +900 -0
  197. parrot/handlers/bots.py +338 -0
  198. parrot/handlers/chat.py +915 -0
  199. parrot/handlers/creation.sql +192 -0
  200. parrot/handlers/crew/ARCHITECTURE.md +362 -0
  201. parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
  202. parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
  203. parrot/handlers/crew/__init__.py +0 -0
  204. parrot/handlers/crew/handler.py +801 -0
  205. parrot/handlers/crew/models.py +229 -0
  206. parrot/handlers/crew/redis_persistence.py +523 -0
  207. parrot/handlers/jobs/__init__.py +10 -0
  208. parrot/handlers/jobs/job.py +384 -0
  209. parrot/handlers/jobs/mixin.py +627 -0
  210. parrot/handlers/jobs/models.py +115 -0
  211. parrot/handlers/jobs/worker.py +31 -0
  212. parrot/handlers/models.py +596 -0
  213. parrot/handlers/o365_auth.py +105 -0
  214. parrot/handlers/stream.py +337 -0
  215. parrot/interfaces/__init__.py +6 -0
  216. parrot/interfaces/aws.py +143 -0
  217. parrot/interfaces/credentials.py +113 -0
  218. parrot/interfaces/database.py +27 -0
  219. parrot/interfaces/google.py +1123 -0
  220. parrot/interfaces/hierarchy.py +1227 -0
  221. parrot/interfaces/http.py +651 -0
  222. parrot/interfaces/images/__init__.py +0 -0
  223. parrot/interfaces/images/plugins/__init__.py +24 -0
  224. parrot/interfaces/images/plugins/abstract.py +58 -0
  225. parrot/interfaces/images/plugins/analisys.py +148 -0
  226. parrot/interfaces/images/plugins/classify.py +150 -0
  227. parrot/interfaces/images/plugins/classifybase.py +182 -0
  228. parrot/interfaces/images/plugins/detect.py +150 -0
  229. parrot/interfaces/images/plugins/exif.py +1103 -0
  230. parrot/interfaces/images/plugins/hash.py +52 -0
  231. parrot/interfaces/images/plugins/vision.py +104 -0
  232. parrot/interfaces/images/plugins/yolo.py +66 -0
  233. parrot/interfaces/images/plugins/zerodetect.py +197 -0
  234. parrot/interfaces/o365.py +978 -0
  235. parrot/interfaces/onedrive.py +822 -0
  236. parrot/interfaces/sharepoint.py +1435 -0
  237. parrot/interfaces/soap.py +257 -0
  238. parrot/loaders/__init__.py +8 -0
  239. parrot/loaders/abstract.py +1131 -0
  240. parrot/loaders/audio.py +199 -0
  241. parrot/loaders/basepdf.py +53 -0
  242. parrot/loaders/basevideo.py +1568 -0
  243. parrot/loaders/csv.py +409 -0
  244. parrot/loaders/docx.py +116 -0
  245. parrot/loaders/epubloader.py +316 -0
  246. parrot/loaders/excel.py +199 -0
  247. parrot/loaders/factory.py +55 -0
  248. parrot/loaders/files/__init__.py +0 -0
  249. parrot/loaders/files/abstract.py +39 -0
  250. parrot/loaders/files/html.py +26 -0
  251. parrot/loaders/files/text.py +63 -0
  252. parrot/loaders/html.py +152 -0
  253. parrot/loaders/markdown.py +442 -0
  254. parrot/loaders/pdf.py +373 -0
  255. parrot/loaders/pdfmark.py +320 -0
  256. parrot/loaders/pdftables.py +506 -0
  257. parrot/loaders/ppt.py +476 -0
  258. parrot/loaders/qa.py +63 -0
  259. parrot/loaders/splitters/__init__.py +10 -0
  260. parrot/loaders/splitters/base.py +138 -0
  261. parrot/loaders/splitters/md.py +228 -0
  262. parrot/loaders/splitters/token.py +143 -0
  263. parrot/loaders/txt.py +26 -0
  264. parrot/loaders/video.py +89 -0
  265. parrot/loaders/videolocal.py +218 -0
  266. parrot/loaders/videounderstanding.py +377 -0
  267. parrot/loaders/vimeo.py +167 -0
  268. parrot/loaders/web.py +599 -0
  269. parrot/loaders/youtube.py +504 -0
  270. parrot/manager/__init__.py +5 -0
  271. parrot/manager/manager.py +1030 -0
  272. parrot/mcp/__init__.py +28 -0
  273. parrot/mcp/adapter.py +105 -0
  274. parrot/mcp/cli.py +174 -0
  275. parrot/mcp/client.py +119 -0
  276. parrot/mcp/config.py +75 -0
  277. parrot/mcp/integration.py +842 -0
  278. parrot/mcp/oauth.py +933 -0
  279. parrot/mcp/server.py +225 -0
  280. parrot/mcp/transports/__init__.py +3 -0
  281. parrot/mcp/transports/base.py +279 -0
  282. parrot/mcp/transports/grpc_session.py +163 -0
  283. parrot/mcp/transports/http.py +312 -0
  284. parrot/mcp/transports/mcp.proto +108 -0
  285. parrot/mcp/transports/quic.py +1082 -0
  286. parrot/mcp/transports/sse.py +330 -0
  287. parrot/mcp/transports/stdio.py +309 -0
  288. parrot/mcp/transports/unix.py +395 -0
  289. parrot/mcp/transports/websocket.py +547 -0
  290. parrot/memory/__init__.py +16 -0
  291. parrot/memory/abstract.py +209 -0
  292. parrot/memory/agent.py +32 -0
  293. parrot/memory/cache.py +175 -0
  294. parrot/memory/core.py +555 -0
  295. parrot/memory/file.py +153 -0
  296. parrot/memory/mem.py +131 -0
  297. parrot/memory/redis.py +613 -0
  298. parrot/models/__init__.py +46 -0
  299. parrot/models/basic.py +118 -0
  300. parrot/models/compliance.py +208 -0
  301. parrot/models/crew.py +395 -0
  302. parrot/models/detections.py +654 -0
  303. parrot/models/generation.py +85 -0
  304. parrot/models/google.py +223 -0
  305. parrot/models/groq.py +23 -0
  306. parrot/models/openai.py +30 -0
  307. parrot/models/outputs.py +285 -0
  308. parrot/models/responses.py +938 -0
  309. parrot/notifications/__init__.py +743 -0
  310. parrot/openapi/__init__.py +3 -0
  311. parrot/openapi/components.yaml +641 -0
  312. parrot/openapi/config.py +322 -0
  313. parrot/outputs/__init__.py +32 -0
  314. parrot/outputs/formats/__init__.py +108 -0
  315. parrot/outputs/formats/altair.py +359 -0
  316. parrot/outputs/formats/application.py +122 -0
  317. parrot/outputs/formats/base.py +351 -0
  318. parrot/outputs/formats/bokeh.py +356 -0
  319. parrot/outputs/formats/card.py +424 -0
  320. parrot/outputs/formats/chart.py +436 -0
  321. parrot/outputs/formats/d3.py +255 -0
  322. parrot/outputs/formats/echarts.py +310 -0
  323. parrot/outputs/formats/generators/__init__.py +0 -0
  324. parrot/outputs/formats/generators/abstract.py +61 -0
  325. parrot/outputs/formats/generators/panel.py +145 -0
  326. parrot/outputs/formats/generators/streamlit.py +86 -0
  327. parrot/outputs/formats/generators/terminal.py +63 -0
  328. parrot/outputs/formats/holoviews.py +310 -0
  329. parrot/outputs/formats/html.py +147 -0
  330. parrot/outputs/formats/jinja2.py +46 -0
  331. parrot/outputs/formats/json.py +87 -0
  332. parrot/outputs/formats/map.py +933 -0
  333. parrot/outputs/formats/markdown.py +172 -0
  334. parrot/outputs/formats/matplotlib.py +237 -0
  335. parrot/outputs/formats/mixins/__init__.py +0 -0
  336. parrot/outputs/formats/mixins/emaps.py +855 -0
  337. parrot/outputs/formats/plotly.py +341 -0
  338. parrot/outputs/formats/seaborn.py +310 -0
  339. parrot/outputs/formats/table.py +397 -0
  340. parrot/outputs/formats/template_report.py +138 -0
  341. parrot/outputs/formats/yaml.py +125 -0
  342. parrot/outputs/formatter.py +152 -0
  343. parrot/outputs/templates/__init__.py +95 -0
  344. parrot/pipelines/__init__.py +0 -0
  345. parrot/pipelines/abstract.py +210 -0
  346. parrot/pipelines/detector.py +124 -0
  347. parrot/pipelines/models.py +90 -0
  348. parrot/pipelines/planogram.py +3002 -0
  349. parrot/pipelines/table.sql +97 -0
  350. parrot/plugins/__init__.py +106 -0
  351. parrot/plugins/importer.py +80 -0
  352. parrot/py.typed +0 -0
  353. parrot/registry/__init__.py +18 -0
  354. parrot/registry/registry.py +594 -0
  355. parrot/scheduler/__init__.py +1189 -0
  356. parrot/scheduler/models.py +60 -0
  357. parrot/security/__init__.py +16 -0
  358. parrot/security/prompt_injection.py +268 -0
  359. parrot/security/security_events.sql +25 -0
  360. parrot/services/__init__.py +1 -0
  361. parrot/services/mcp/__init__.py +8 -0
  362. parrot/services/mcp/config.py +13 -0
  363. parrot/services/mcp/server.py +295 -0
  364. parrot/services/o365_remote_auth.py +235 -0
  365. parrot/stores/__init__.py +7 -0
  366. parrot/stores/abstract.py +352 -0
  367. parrot/stores/arango.py +1090 -0
  368. parrot/stores/bigquery.py +1377 -0
  369. parrot/stores/cache.py +106 -0
  370. parrot/stores/empty.py +10 -0
  371. parrot/stores/faiss_store.py +1157 -0
  372. parrot/stores/kb/__init__.py +9 -0
  373. parrot/stores/kb/abstract.py +68 -0
  374. parrot/stores/kb/cache.py +165 -0
  375. parrot/stores/kb/doc.py +325 -0
  376. parrot/stores/kb/hierarchy.py +346 -0
  377. parrot/stores/kb/local.py +457 -0
  378. parrot/stores/kb/prompt.py +28 -0
  379. parrot/stores/kb/redis.py +659 -0
  380. parrot/stores/kb/store.py +115 -0
  381. parrot/stores/kb/user.py +374 -0
  382. parrot/stores/models.py +59 -0
  383. parrot/stores/pgvector.py +3 -0
  384. parrot/stores/postgres.py +2853 -0
  385. parrot/stores/utils/__init__.py +0 -0
  386. parrot/stores/utils/chunking.py +197 -0
  387. parrot/telemetry/__init__.py +3 -0
  388. parrot/telemetry/mixin.py +111 -0
  389. parrot/template/__init__.py +3 -0
  390. parrot/template/engine.py +259 -0
  391. parrot/tools/__init__.py +23 -0
  392. parrot/tools/abstract.py +644 -0
  393. parrot/tools/agent.py +363 -0
  394. parrot/tools/arangodbsearch.py +537 -0
  395. parrot/tools/arxiv_tool.py +188 -0
  396. parrot/tools/calculator/__init__.py +3 -0
  397. parrot/tools/calculator/operations/__init__.py +38 -0
  398. parrot/tools/calculator/operations/calculus.py +80 -0
  399. parrot/tools/calculator/operations/statistics.py +76 -0
  400. parrot/tools/calculator/tool.py +150 -0
  401. parrot/tools/cloudwatch.py +988 -0
  402. parrot/tools/codeinterpreter/__init__.py +127 -0
  403. parrot/tools/codeinterpreter/executor.py +371 -0
  404. parrot/tools/codeinterpreter/internals.py +473 -0
  405. parrot/tools/codeinterpreter/models.py +643 -0
  406. parrot/tools/codeinterpreter/prompts.py +224 -0
  407. parrot/tools/codeinterpreter/tool.py +664 -0
  408. parrot/tools/company_info/__init__.py +6 -0
  409. parrot/tools/company_info/tool.py +1138 -0
  410. parrot/tools/correlationanalysis.py +437 -0
  411. parrot/tools/database/abstract.py +286 -0
  412. parrot/tools/database/bq.py +115 -0
  413. parrot/tools/database/cache.py +284 -0
  414. parrot/tools/database/models.py +95 -0
  415. parrot/tools/database/pg.py +343 -0
  416. parrot/tools/databasequery.py +1159 -0
  417. parrot/tools/db.py +1800 -0
  418. parrot/tools/ddgo.py +370 -0
  419. parrot/tools/decorators.py +271 -0
  420. parrot/tools/dftohtml.py +282 -0
  421. parrot/tools/document.py +549 -0
  422. parrot/tools/ecs.py +819 -0
  423. parrot/tools/edareport.py +368 -0
  424. parrot/tools/elasticsearch.py +1049 -0
  425. parrot/tools/employees.py +462 -0
  426. parrot/tools/epson/__init__.py +96 -0
  427. parrot/tools/excel.py +683 -0
  428. parrot/tools/file/__init__.py +13 -0
  429. parrot/tools/file/abstract.py +76 -0
  430. parrot/tools/file/gcs.py +378 -0
  431. parrot/tools/file/local.py +284 -0
  432. parrot/tools/file/s3.py +511 -0
  433. parrot/tools/file/tmp.py +309 -0
  434. parrot/tools/file/tool.py +501 -0
  435. parrot/tools/file_reader.py +129 -0
  436. parrot/tools/flowtask/__init__.py +19 -0
  437. parrot/tools/flowtask/tool.py +761 -0
  438. parrot/tools/gittoolkit.py +508 -0
  439. parrot/tools/google/__init__.py +18 -0
  440. parrot/tools/google/base.py +169 -0
  441. parrot/tools/google/tools.py +1251 -0
  442. parrot/tools/googlelocation.py +5 -0
  443. parrot/tools/googleroutes.py +5 -0
  444. parrot/tools/googlesearch.py +5 -0
  445. parrot/tools/googlesitesearch.py +5 -0
  446. parrot/tools/googlevoice.py +2 -0
  447. parrot/tools/gvoice.py +695 -0
  448. parrot/tools/ibisworld/README.md +225 -0
  449. parrot/tools/ibisworld/__init__.py +11 -0
  450. parrot/tools/ibisworld/tool.py +366 -0
  451. parrot/tools/jiratoolkit.py +1718 -0
  452. parrot/tools/manager.py +1098 -0
  453. parrot/tools/math.py +152 -0
  454. parrot/tools/metadata.py +476 -0
  455. parrot/tools/msteams.py +1621 -0
  456. parrot/tools/msword.py +635 -0
  457. parrot/tools/multidb.py +580 -0
  458. parrot/tools/multistoresearch.py +369 -0
  459. parrot/tools/networkninja.py +167 -0
  460. parrot/tools/nextstop/__init__.py +4 -0
  461. parrot/tools/nextstop/base.py +286 -0
  462. parrot/tools/nextstop/employee.py +733 -0
  463. parrot/tools/nextstop/store.py +462 -0
  464. parrot/tools/notification.py +435 -0
  465. parrot/tools/o365/__init__.py +42 -0
  466. parrot/tools/o365/base.py +295 -0
  467. parrot/tools/o365/bundle.py +522 -0
  468. parrot/tools/o365/events.py +554 -0
  469. parrot/tools/o365/mail.py +992 -0
  470. parrot/tools/o365/onedrive.py +497 -0
  471. parrot/tools/o365/sharepoint.py +641 -0
  472. parrot/tools/openapi_toolkit.py +904 -0
  473. parrot/tools/openweather.py +527 -0
  474. parrot/tools/pdfprint.py +1001 -0
  475. parrot/tools/powerbi.py +518 -0
  476. parrot/tools/powerpoint.py +1113 -0
  477. parrot/tools/pricestool.py +146 -0
  478. parrot/tools/products/__init__.py +246 -0
  479. parrot/tools/prophet_tool.py +171 -0
  480. parrot/tools/pythonpandas.py +630 -0
  481. parrot/tools/pythonrepl.py +910 -0
  482. parrot/tools/qsource.py +436 -0
  483. parrot/tools/querytoolkit.py +395 -0
  484. parrot/tools/quickeda.py +827 -0
  485. parrot/tools/resttool.py +553 -0
  486. parrot/tools/retail/__init__.py +0 -0
  487. parrot/tools/retail/bby.py +528 -0
  488. parrot/tools/sandboxtool.py +703 -0
  489. parrot/tools/sassie/__init__.py +352 -0
  490. parrot/tools/scraping/__init__.py +7 -0
  491. parrot/tools/scraping/docs/select.md +466 -0
  492. parrot/tools/scraping/documentation.md +1278 -0
  493. parrot/tools/scraping/driver.py +436 -0
  494. parrot/tools/scraping/models.py +576 -0
  495. parrot/tools/scraping/options.py +85 -0
  496. parrot/tools/scraping/orchestrator.py +517 -0
  497. parrot/tools/scraping/readme.md +740 -0
  498. parrot/tools/scraping/tool.py +3115 -0
  499. parrot/tools/seasonaldetection.py +642 -0
  500. parrot/tools/shell_tool/__init__.py +5 -0
  501. parrot/tools/shell_tool/actions.py +408 -0
  502. parrot/tools/shell_tool/engine.py +155 -0
  503. parrot/tools/shell_tool/models.py +322 -0
  504. parrot/tools/shell_tool/tool.py +442 -0
  505. parrot/tools/site_search.py +214 -0
  506. parrot/tools/textfile.py +418 -0
  507. parrot/tools/think.py +378 -0
  508. parrot/tools/toolkit.py +298 -0
  509. parrot/tools/webapp_tool.py +187 -0
  510. parrot/tools/whatif.py +1279 -0
  511. parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
  512. parrot/tools/workday/__init__.py +6 -0
  513. parrot/tools/workday/models.py +1389 -0
  514. parrot/tools/workday/tool.py +1293 -0
  515. parrot/tools/yfinance_tool.py +306 -0
  516. parrot/tools/zipcode.py +217 -0
  517. parrot/utils/__init__.py +2 -0
  518. parrot/utils/helpers.py +73 -0
  519. parrot/utils/parsers/__init__.py +5 -0
  520. parrot/utils/parsers/toml.c +12078 -0
  521. parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
  522. parrot/utils/parsers/toml.pyx +21 -0
  523. parrot/utils/toml.py +11 -0
  524. parrot/utils/types.cpp +20936 -0
  525. parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
  526. parrot/utils/types.pyx +213 -0
  527. parrot/utils/uv.py +11 -0
  528. parrot/version.py +10 -0
  529. parrot/yaml-rs/Cargo.lock +350 -0
  530. parrot/yaml-rs/Cargo.toml +19 -0
  531. parrot/yaml-rs/pyproject.toml +19 -0
  532. parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
  533. parrot/yaml-rs/src/lib.rs +222 -0
  534. requirements/docker-compose.yml +24 -0
  535. requirements/requirements-dev.txt +21 -0
parrot/tools/db.py ADDED
@@ -0,0 +1,1800 @@
1
+ """
2
+ Unified Database Tool for AI-Parrot
3
+
4
+ Consolidates schema extraction, knowledge base building, query generation,
5
+ validation, and execution into a single, powerful database interface.
6
+ """
7
+ from __future__ import annotations
8
+ from typing import Dict, List, Optional, Any, Union, Literal, Tuple
9
+ import re
10
+ import asyncio
11
+ import json
12
+ import hashlib
13
+ from datetime import datetime, timedelta, timezone
14
+ from enum import Enum
15
+ import pandas as pd
16
+ from pydantic import (
17
+ BaseModel,
18
+ Field,
19
+ field_validator,
20
+ model_validator
21
+ )
22
+ from asyncdb import AsyncDB
23
+ from .abstract import (
24
+ AbstractTool,
25
+ ToolResult,
26
+ AbstractToolArgsSchema
27
+ )
28
+ from ..stores.abstract import AbstractStore
29
+ from ..clients.base import AbstractClient
30
+ from ..clients.factory import LLMFactory
31
+ from ..models import AIMessage
32
+
33
+
34
class DatabaseFlavor(str, Enum):
    """Database engines this tool can target.

    Each member's value is the lowercase engine identifier; subclassing
    ``str`` means a member compares equal to its raw string value, so
    callers may pass either the enum or the plain string.
    """
    POSTGRESQL = "postgresql"     # relational
    MYSQL = "mysql"               # relational
    SQLSERVER = "sqlserver"       # relational
    BIGQUERY = "bigquery"         # analytical warehouse
    INFLUXDB = "influxdb"         # time-series
    CASSANDRA = "cassandra"       # wide-column NoSQL
    MONGODB = "mongodb"           # document NoSQL
    ELASTICSEARCH = "elasticsearch"  # search engine
    SQLITE = "sqlite"             # embedded relational
    DUCKDB = "duckdb"             # embedded analytical
46
+
47
+
48
class QueryType(str, Enum):
    """SQL statement categories the validator can classify a query as."""
    SELECT = "SELECT"  # read-only
    INSERT = "INSERT"  # DML
    UPDATE = "UPDATE"  # DML
    DELETE = "DELETE"  # DML
    CREATE = "CREATE"  # DDL
    ALTER = "ALTER"    # DDL
    DROP = "DROP"      # DDL
57
+
58
+
59
class OutputFormat(str, Enum):
    """Shapes a query result can be returned in."""
    PANDAS = "pandas"          # pandas DataFrame
    JSON = "json"              # JSON string
    DICT = "dict"              # plain Python dicts
    CSV = "csv"                # CSV text
    STRUCTURED = "structured"  # validated via Pydantic models
66
+
67
+
68
class SchemaMetadata(BaseModel):
    """Metadata for a database schema.

    A snapshot of the objects discovered in one schema, tagged with the
    database flavor it came from and a capture timestamp (used for cache
    expiry against ``cache_duration_hours``).
    """
    schema_name: str  # e.g. "public"
    tables: List[Dict[str, Any]]       # one dict per table; exact layout set by the extractor — TODO confirm
    views: List[Dict[str, Any]]        # view definitions
    functions: List[Dict[str, Any]]    # stored functions
    procedures: List[Dict[str, Any]]   # stored procedures
    indexes: List[Dict[str, Any]]      # index metadata
    constraints: List[Dict[str, Any]]  # PK/FK/unique/check constraints
    last_updated: datetime             # when this snapshot was captured
    database_flavor: DatabaseFlavor    # engine the schema was extracted from
79
+
80
+
81
class QueryValidationResult(BaseModel):
    """Result of query validation.

    Presumably produced by the flavor-specific ``_validate_*`` handlers;
    ``is_valid`` is the overall verdict and the other fields carry
    diagnostics for the caller.
    """
    is_valid: bool                     # overall pass/fail verdict
    query_type: Optional[QueryType]    # None when the statement type could not be classified
    affected_tables: List[str]         # tables the query touches
    estimated_cost: Optional[float]    # planner cost estimate, when available
    warnings: List[str]                # non-fatal issues
    errors: List[str]                  # fatal issues — presumably imply is_valid is False; TODO confirm
    security_checks: Dict[str, bool]   # named security check -> passed?
90
+
91
+
92
class DatabaseToolArgs(AbstractToolArgsSchema):
    """Arguments for the unified database tool.

    Groups query specification, connection details, operation routing,
    execution limits, output formatting, and knowledge-base caching
    options into a single validated schema.
    """

    # Query specification — at least one of these two is required for
    # query-style operations (enforced by validate_query_input below).
    natural_language_query: Optional[str] = Field(
        None, description="Natural language description of what you want to query"
    )
    sql_query: Optional[str] = Field(
        None, description="Direct SQL query to execute"
    )

    # Database connection
    database_flavor: DatabaseFlavor = Field(
        DatabaseFlavor.POSTGRESQL, description="Type of database to connect to"
    )
    connection_params: Optional[Dict[str, Any]] = Field(
        None, description="Database connection parameters"
    )
    schema_names: List[str] = Field(
        default=["public"], description="Schema names to work with"
    )

    # Operation modes — which stage(s) of the pipeline to run
    operation: Literal[
        "schema_extract", "query_generate", "query_validate",
        "query_execute", "full_pipeline", "explain_query"
    ] = Field(
        "full_pipeline", description="What operation to perform"
    )

    # Query options
    max_rows: int = Field(1000, description="Maximum rows to return")
    timeout_seconds: int = Field(300, description="Query timeout")
    dry_run: bool = Field(False, description="Validate without executing")

    # Output options
    output_format: OutputFormat = Field(
        OutputFormat.PANDAS, description="Format for query results"
    )
    structured_output_schema: Optional[Dict[str, Any]] = Field(
        None, description="Pydantic schema for structured outputs"
    )

    # Knowledge base options
    update_knowledge_base: bool = Field(
        True, description="Whether to update schema knowledge base"
    )
    cache_duration_hours: int = Field(
        24, description="How long to cache schema metadata"
    )

    @model_validator(mode='after')
    def validate_query_input(self) -> 'DatabaseToolArgs':
        # Ensure at least one query type is provided for query operations.
        # Operations not listed here (schema_extract, query_validate) may
        # legitimately omit both query fields.
        if self.operation in ['query_generate', 'query_execute', 'full_pipeline', 'explain_query']:
            if not self.natural_language_query and not self.sql_query:
                raise ValueError("Either natural_language_query or sql_query must be provided")
        return self
150
+
151
+
152
+ class DatabaseTool(AbstractTool):
153
+ """
154
+ Unified Database Tool that handles the complete database interaction pipeline:
155
+
156
+ 1. Schema Discovery: Extract and cache table schemas from any supported database
157
+ 2. Knowledge Base Building: Store schema metadata in vector store for RAG
158
+ 3. Query Generation: Convert natural language to database-specific queries
159
+ 4. Query Validation: Syntax checking, security validation, cost estimation
160
+ 5. Query Execution: Safe execution with proper error handling
161
+ 6. Structured Output: Format results according to specified schemas
162
+
163
+ This tool consolidates the functionality of SchemaTool, DatabaseQueryTool,
164
+ and SQLAgent into a single, cohesive interface.
165
+ """
166
+
167
+ name = "database_tool"
168
+ description = """Unified database tool for schema discovery, query generation,
169
+ validation, and execution across multiple database types"""
170
+ args_schema = DatabaseToolArgs
171
+
172
+ def __init__(
173
+ self,
174
+ knowledge_store: Optional[AbstractStore] = None,
175
+ default_connection_params: Optional[Dict[DatabaseFlavor, Dict]] = None,
176
+ enable_query_caching: bool = True,
177
+ llm: Optional[Union[AbstractClient, str]] = None,
178
+ **kwargs
179
+ ):
180
+ """
181
+ Initialize the unified database tool.
182
+
183
+ Args:
184
+ knowledge_store: Vector store for schema metadata and RAG
185
+ default_connection_params: Default connection parameters per database type
186
+ enable_query_caching: Whether to cache query results
187
+ llm: LLM to use for query generation and validation
188
+ """
189
+ super().__init__(**kwargs)
190
+
191
+ self.knowledge_store = knowledge_store
192
+ self.default_connection_params = default_connection_params or {}
193
+ self.enable_query_caching = enable_query_caching
194
+
195
+ # Initialize LLM
196
+ if isinstance(llm, str):
197
+ self.llm = LLMFactory.create(llm)
198
+ else:
199
+ self.llm = llm
200
+
201
+ # Cache for schema metadata and database connections
202
+ self._schema_cache: Dict[str, Tuple[SchemaMetadata, datetime]] = {}
203
+ self._connection_cache: Dict[str, AsyncDB] = {}
204
+
205
+ # Database-specific query generators and validators
206
+ self._query_generators = {}
207
+ self._query_validators = {}
208
+
209
+ self._setup_database_handlers()
210
+
211
    def _setup_database_handlers(self):
        """Initialize database-specific handlers for different flavors.

        Populates the flavor -> callable dispatch maps used for query
        generation and validation. Only PostgreSQL, MySQL, and BigQuery
        are wired up so far.

        NOTE(review): flavors absent from these maps have no handler —
        how a lookup miss is handled depends on the calling code; confirm
        callers use .get() or guard the flavor before dispatching.
        """
        # This would be expanded to include handlers for each database type
        self._query_generators = {
            DatabaseFlavor.POSTGRESQL: self._generate_postgresql_query,
            DatabaseFlavor.MYSQL: self._generate_mysql_query,
            DatabaseFlavor.BIGQUERY: self._generate_bigquery_query,
            # Add more database-specific generators...
        }

        self._query_validators = {
            DatabaseFlavor.POSTGRESQL: self._validate_postgresql_query,
            DatabaseFlavor.MYSQL: self._validate_mysql_query,
            DatabaseFlavor.BIGQUERY: self._validate_bigquery_query,
            # Add more database-specific validators...
        }
227
+
228
+ def _clean_sql(self, sql_query: str) -> str:
229
+ """Clean SQL query from markdown formatting."""
230
+ if not sql_query:
231
+ return ""
232
+ # Remove markdown code blocks
233
+ clean_query = re.sub(r'```\w*\n?', '', sql_query)
234
+ clean_query = clean_query.replace('```', '')
235
+ return clean_query.strip()
236
+
237
+ async def _execute(
238
+ self,
239
+ natural_language_query: Optional[str] = None,
240
+ sql_query: Optional[str] = None,
241
+ database_flavor: DatabaseFlavor = DatabaseFlavor.POSTGRESQL,
242
+ connection_params: Optional[Dict[str, Any]] = None,
243
+ schema_names: List[str] = ["public"],
244
+ operation: str = "full_pipeline",
245
+ max_rows: int = 1000,
246
+ timeout_seconds: int = 300,
247
+ dry_run: bool = False,
248
+ output_format: OutputFormat = OutputFormat.PANDAS,
249
+ structured_output_schema: Optional[Dict[str, Any]] = None,
250
+ update_knowledge_base: bool = True,
251
+ cache_duration_hours: int = 24,
252
+ **kwargs
253
+ ) -> ToolResult:
254
+ """
255
+ Execute the unified database tool pipeline.
256
+
257
+ The method routes to different sub-operations based on the operation parameter,
258
+ or executes the full pipeline for complete query processing.
259
+ """
260
+ try:
261
+ # Fallback to default connection parameters if not provided
262
+ if connection_params is None:
263
+ connection_params = self.default_connection_params.get(database_flavor)
264
+
265
+ if sql_query:
266
+ sql_query = self._clean_sql(sql_query)
267
+
268
+ # Route to specific operations
269
+ if operation == "schema_extract":
270
+ return await self._extract_schema_operation(
271
+ database_flavor, connection_params, schema_names,
272
+ update_knowledge_base, cache_duration_hours
273
+ )
274
+ if operation == "query_generate":
275
+ return await self._query_generation_operation(
276
+ natural_language_query, database_flavor, connection_params, schema_names
277
+ )
278
+ if operation == "query_validate":
279
+ return await self._query_validation_operation(
280
+ sql_query or natural_language_query, database_flavor, connection_params
281
+ )
282
+ if operation == "query_execute":
283
+ return await self._query_execution_operation(
284
+ sql_query, database_flavor, connection_params,
285
+ max_rows, timeout_seconds, output_format, structured_output_schema
286
+ )
287
+ if operation == "full_pipeline":
288
+ return await self._full_pipeline_operation(
289
+ natural_language_query, sql_query, database_flavor, connection_params,
290
+ schema_names, max_rows, timeout_seconds, dry_run,
291
+ output_format, structured_output_schema, update_knowledge_base, cache_duration_hours
292
+ )
293
+ if operation == "explain_query":
294
+ return await self._explain_query_operation(
295
+ sql_query or natural_language_query, database_flavor, connection_params
296
+ )
297
+ else:
298
+ raise ValueError(f"Unknown operation: {operation}")
299
+
300
+ except Exception as e:
301
+ return ToolResult(
302
+ status="error",
303
+ result=None,
304
+ error=f"Database tool execution failed: {str(e)}",
305
+ metadata={
306
+ "operation": operation,
307
+ "database_flavor": database_flavor.value,
308
+ "timestamp": datetime.now(timezone.utc).isoformat()
309
+ }
310
+ )
311
+
312
+ async def _full_pipeline_operation(
313
+ self,
314
+ natural_language_query: Optional[str],
315
+ sql_query: Optional[str],
316
+ database_flavor: DatabaseFlavor,
317
+ connection_params: Optional[Dict[str, Any]],
318
+ schema_names: List[str],
319
+ max_rows: int,
320
+ timeout_seconds: int,
321
+ dry_run: bool,
322
+ output_format: OutputFormat,
323
+ structured_output_schema: Optional[Dict[str, Any]],
324
+ update_knowledge_base: bool,
325
+ cache_duration_hours: int
326
+ ) -> ToolResult:
327
+ """
328
+ Execute the complete database interaction pipeline.
329
+
330
+ This is the main orchestrator method that combines all functionality:
331
+ schema extraction, knowledge base updates, query generation, validation, and execution.
332
+ """
333
+ pipeline_results = {
334
+ "schema_extraction": None,
335
+ "query_generation": None,
336
+ "query_validation": None,
337
+ "query_execution": None,
338
+ "knowledge_base_update": None
339
+ }
340
+
341
+ try:
342
+ # Step 1: Extract and cache schema metadata
343
+ self.logger.info(f"Step 1: Extracting schema for {database_flavor.value}")
344
+ schema_result = await self._extract_schema_operation(
345
+ database_flavor, connection_params, schema_names,
346
+ update_knowledge_base, cache_duration_hours
347
+ )
348
+ pipeline_results["schema_extraction"] = schema_result.result
349
+
350
+ # Step 2: Generate SQL query if natural language was provided
351
+ generated_query = sql_query
352
+ if natural_language_query:
353
+ self.logger.info("Step 2: Generating SQL from natural language")
354
+ query_result = await self._query_generation_operation(
355
+ natural_language_query, database_flavor, connection_params, schema_names
356
+ )
357
+ pipeline_results["query_generation"] = query_result.result
358
+ generated_query = query_result.result.get("sql_query")
359
+
360
+ if not generated_query:
361
+ raise ValueError("No valid SQL query to execute")
362
+
363
+ # Step 3: Validate the query
364
+ self.logger.info("Step 3: Validating SQL query")
365
+ validation_result = await self._query_validation_operation(
366
+ generated_query, database_flavor, connection_params
367
+ )
368
+ pipeline_results["query_validation"] = validation_result.result
369
+
370
+ if not validation_result.result["is_valid"]:
371
+ if dry_run:
372
+ return ToolResult(
373
+ status="success",
374
+ result={
375
+ "pipeline_results": pipeline_results,
376
+ "dry_run": True,
377
+ "query_valid": False
378
+ },
379
+ metadata={"operation": "full_pipeline", "dry_run": True}
380
+ )
381
+ else:
382
+ raise ValueError(f"Query validation failed: {validation_result.result['errors']}")
383
+
384
+ # Step 4: Execute the query (unless dry run)
385
+ if not dry_run:
386
+ self.logger.info("Step 4: Executing validated query")
387
+ execution_result = await self._query_execution_operation(
388
+ generated_query, database_flavor, connection_params,
389
+ max_rows, timeout_seconds, output_format, structured_output_schema
390
+ )
391
+ pipeline_results["query_execution"] = execution_result.result
392
+
393
+ # Success! Return comprehensive results
394
+ return ToolResult(
395
+ status="success",
396
+ result={
397
+ "pipeline_results": pipeline_results,
398
+ "final_query": generated_query,
399
+ "dry_run": dry_run,
400
+ "execution_summary": {
401
+ "rows_returned": len(pipeline_results["query_execution"]["data"]) if not dry_run and pipeline_results["query_execution"] else 0,
402
+ "execution_time_seconds": pipeline_results["query_execution"]["execution_time"] if not dry_run and pipeline_results["query_execution"] else None,
403
+ "output_format": output_format.value
404
+ }
405
+ },
406
+ metadata={
407
+ "operation": "full_pipeline",
408
+ "database_flavor": database_flavor.value,
409
+ "schema_count": len(schema_names),
410
+ "natural_language_input": natural_language_query is not None,
411
+ "timestamp": datetime.utcnow().isoformat()
412
+ }
413
+ )
414
+
415
+ except Exception as e:
416
+ return ToolResult(
417
+ status="error",
418
+ result={"pipeline_results": pipeline_results},
419
+ error=f"Pipeline failed at step: {str(e)}",
420
+ metadata={"operation": "full_pipeline", "partial_results": True}
421
+ )
422
+
423
+ async def _extract_schema_operation(
424
+ self,
425
+ database_flavor: DatabaseFlavor,
426
+ connection_params: Optional[Dict[str, Any]],
427
+ schema_names: List[str],
428
+ update_knowledge_base: bool,
429
+ cache_duration_hours: int
430
+ ) -> ToolResult:
431
+ """Extract database schema metadata and optionally update knowledge base."""
432
+ try:
433
+ # Check cache first
434
+ cache_key = self._generate_schema_cache_key(database_flavor, connection_params, schema_names)
435
+ cached_schema, cache_time = self._schema_cache.get(cache_key, (None, None))
436
+
437
+ if cached_schema and cache_time:
438
+ cache_age = datetime.utcnow() - cache_time
439
+ if cache_age < timedelta(hours=cache_duration_hours):
440
+ self.logger.info(f"Using cached schema metadata (age: {cache_age})")
441
+ return ToolResult(
442
+ status="success",
443
+ result=cached_schema.dict(),
444
+ metadata={"source": "cache", "cache_age_hours": cache_age.total_seconds() / 3600}
445
+ )
446
+
447
+ # Extract fresh schema metadata
448
+ db_connection = await self._get_database_connection(database_flavor, connection_params)
449
+ schema_metadata = await self._extract_database_schema(db_connection, database_flavor, schema_names)
450
+
451
+ # Cache the results
452
+ self._schema_cache[cache_key] = (schema_metadata, datetime.utcnow())
453
+
454
+ # Update knowledge base if requested
455
+ if update_knowledge_base and self.knowledge_store:
456
+ await self._update_schema_knowledge_base(schema_metadata)
457
+
458
+ return ToolResult(
459
+ status="success",
460
+ result=schema_metadata.dict(),
461
+ metadata={
462
+ "source": "database",
463
+ "schema_count": len(schema_names),
464
+ "table_count": len(schema_metadata.tables),
465
+ "view_count": len(schema_metadata.views),
466
+ "knowledge_base_updated": update_knowledge_base and self.knowledge_store is not None
467
+ }
468
+ )
469
+
470
+ except Exception as e:
471
+ return ToolResult(
472
+ status="error",
473
+ result=None,
474
+ error=f"Schema extraction failed: {str(e)}",
475
+ metadata={"operation": "schema_extract"}
476
+ )
477
+
478
+ # Additional helper methods would continue here...
479
+ # Including _query_generation_operation, _query_validation_operation,
480
+ # _query_execution_operation, and all the database-specific implementations
481
+
482
+ def _generate_schema_cache_key(
483
+ self,
484
+ database_flavor: DatabaseFlavor,
485
+ connection_params: Optional[Dict[str, Any]],
486
+ schema_names: List[str]
487
+ ) -> str:
488
+ """Generate a unique cache key for schema metadata."""
489
+ key_data = {
490
+ "flavor": database_flavor.value,
491
+ "params": connection_params or {},
492
+ "schemas": sorted(schema_names)
493
+ }
494
+ return hashlib.md5(json.dumps(key_data, sort_keys=True).encode()).hexdigest()
495
+
496
+ async def _get_database_connection(
497
+ self,
498
+ database_flavor: DatabaseFlavor,
499
+ connection_params: Optional[Dict[str, Any]]
500
+ ) -> AsyncDB:
501
+ """Get or create a database connection using AsyncDB."""
502
+ """Get or create a database connection using AsyncDB."""
503
+ # Normalize connection parameters
504
+ params = connection_params.copy() if connection_params else {}
505
+
506
+ # Common mapping: username -> user (used by asyncpg and others)
507
+ if 'username' in params and 'user' not in params:
508
+ params['user'] = params.pop('username')
509
+
510
+ driver_map = {
511
+ DatabaseFlavor.POSTGRESQL: 'pg',
512
+ DatabaseFlavor.MYSQL: 'mysql',
513
+ DatabaseFlavor.SQLITE: 'sqlite',
514
+ }
515
+ driver = driver_map.get(database_flavor, database_flavor.value)
516
+ return AsyncDB(driver, params=params)
517
+
518
+ async def _extract_database_schema(
519
+ self,
520
+ db_connection: AsyncDB,
521
+ database_flavor: DatabaseFlavor,
522
+ schema_names: List[str]
523
+ ) -> SchemaMetadata:
524
+ """Extract comprehensive schema metadata from the database."""
525
+ """Extract comprehensive schema metadata from the database."""
526
+ if database_flavor == DatabaseFlavor.POSTGRESQL:
527
+ return await self._extract_postgresql_schema(db_connection, schema_names)
528
+
529
+ raise NotImplementedError(f"Schema extraction not implemented for {database_flavor}")
530
+
531
+ async def _extract_postgresql_schema(
532
+ self,
533
+ db: AsyncDB,
534
+ schema_names: List[str]
535
+ ) -> SchemaMetadata:
536
+ """Extract schema for PostgreSQL."""
537
+ tables_data = []
538
+ async with await db.connection() as conn:
539
+ schemas_list = ", ".join([f"'{s}'" for s in schema_names])
540
+ if not schemas_list:
541
+ schemas_list = "'public'" # Default
542
+
543
+ query = f"""
544
+ SELECT t.table_schema, t.table_name, c.column_name, c.data_type
545
+ FROM information_schema.tables t
546
+ JOIN information_schema.columns c
547
+ ON t.table_schema = c.table_schema AND t.table_name = c.table_name
548
+ WHERE t.table_schema IN ({schemas_list})
549
+ ORDER BY t.table_schema, t.table_name, c.ordinal_position
550
+ """
551
+ try:
552
+ rows = await conn.fetch(query) # Using fetch if available, or query
553
+ except Exception:
554
+ # Fallback to query if fetch not available on conn wrapper
555
+ rows = await conn.query(query)
556
+
557
+ # Check if rows is a list of lists (result set wrapper)
558
+ if rows and isinstance(rows, list) and len(rows) > 0 and isinstance(rows[0], list):
559
+ rows = rows[0]
560
+
561
+ # Process rows
562
+ grouped = {}
563
+ for row in rows:
564
+ # Handle possible dict or object access
565
+ # asyncpg.Record supports .get() and ['key']
566
+ if hasattr(row, 'get'):
567
+ s_name = row.get('table_schema')
568
+ t_name = row.get('table_name')
569
+ c_name = row.get('column_name')
570
+ d_type = row.get('data_type')
571
+ elif isinstance(row, (list, tuple)) and len(row) >= 4:
572
+ s_name = row[0]
573
+ t_name = row[1]
574
+ c_name = row[2]
575
+ d_type = row[3]
576
+ else:
577
+ # Attempt dict access as fallback
578
+ try:
579
+ s_name = row['table_schema']
580
+ t_name = row['table_name']
581
+ c_name = row['column_name']
582
+ d_type = row['data_type']
583
+ except (TypeError, KeyError, IndexError):
584
+ continue # Skip invalid rows
585
+
586
+ k = (s_name, t_name)
587
+ if k not in grouped:
588
+ grouped[k] = {
589
+ "schema": s_name,
590
+ "name": t_name,
591
+ "columns": []
592
+ }
593
+ grouped[k]["columns"].append({"name": c_name, "type": d_type})
594
+
595
+ tables_data = list(grouped.values())
596
+
597
+ return SchemaMetadata(
598
+ schema_name=",".join(schema_names),
599
+ tables=tables_data,
600
+ views=[],
601
+ functions=[],
602
+ procedures=[],
603
+ indexes=[],
604
+ constraints=[],
605
+ last_updated=datetime.utcnow(),
606
+ database_flavor=DatabaseFlavor.POSTGRESQL
607
+ )
608
+
609
+ async def _query_generation_operation(
610
+ self,
611
+ natural_language_query: str,
612
+ database_flavor: DatabaseFlavor,
613
+ connection_params: Optional[Dict[str, Any]],
614
+ schema_names: List[str]
615
+ ) -> ToolResult:
616
+ """Generate SQL query from natural language using schema context."""
617
+ try:
618
+ # Get schema context for query generation
619
+ schema_key = self._generate_schema_cache_key(database_flavor, connection_params, schema_names)
620
+ cached_schema, _ = self._schema_cache.get(schema_key, (None, None))
621
+
622
+ if not cached_schema:
623
+ # If no cached schema, extract it first
624
+ schema_result = await self._extract_schema_operation(
625
+ database_flavor, connection_params, schema_names, False, 24
626
+ )
627
+ if schema_result.status != "success" or not schema_result.result:
628
+ raise ValueError(f"Schema extraction failed: {schema_result.error or 'No result returned'}")
629
+
630
+ cached_schema = SchemaMetadata(**schema_result.result)
631
+
632
+ # Use database-specific query generator
633
+ generator = self._query_generators.get(database_flavor)
634
+ if not generator:
635
+ raise ValueError(f"No query generator available for {database_flavor.value}")
636
+
637
+ # Build rich context for LLM query generation
638
+ schema_context = self._build_schema_context_for_llm(cached_schema, natural_language_query)
639
+
640
+ # Generate the SQL query
641
+ generated_sql = await generator(natural_language_query, schema_context)
642
+ generated_sql = self._clean_sql(generated_sql)
643
+
644
+ return ToolResult(
645
+ status="success",
646
+ result={
647
+ "natural_language_query": natural_language_query,
648
+ "sql_query": generated_sql,
649
+ "database_flavor": database_flavor.value,
650
+ "schema_context_used": len(schema_context.get("relevant_tables", [])),
651
+ "generation_timestamp": datetime.utcnow().isoformat()
652
+ },
653
+ metadata={
654
+ "operation": "query_generation",
655
+ "has_schema_context": bool(schema_context)
656
+ }
657
+ )
658
+
659
+ except Exception as e:
660
+ return ToolResult(
661
+ status="error",
662
+ result=None,
663
+ error=f"Query generation failed: {str(e)}",
664
+ metadata={"operation": "query_generation"}
665
+ )
666
+
667
+ async def _query_validation_operation(
668
+ self,
669
+ sql_query: str,
670
+ database_flavor: DatabaseFlavor,
671
+ connection_params: Optional[Dict[str, Any]]
672
+ ) -> ToolResult:
673
+ """Validate SQL query for syntax, security, and performance."""
674
+ try:
675
+ validator = self._query_validators.get(database_flavor)
676
+ if not validator:
677
+ raise ValueError(f"No query validator available for {database_flavor.value}")
678
+
679
+ validation_result = await validator(sql_query)
680
+
681
+ return ToolResult(
682
+ status="success" if validation_result.is_valid else "warning",
683
+ result=validation_result.dict(),
684
+ metadata={
685
+ "operation": "query_validation",
686
+ "query_type": validation_result.query_type.value if validation_result.query_type else None
687
+ }
688
+ )
689
+
690
+ except Exception as e:
691
+ return ToolResult(
692
+ status="error",
693
+ result=None,
694
+ error=f"Query validation failed: {str(e)}",
695
+ metadata={"operation": "query_validation"}
696
+ )
697
+
698
+ async def _query_execution_operation(
699
+ self,
700
+ sql_query: str,
701
+ database_flavor: DatabaseFlavor,
702
+ connection_params: Optional[Dict[str, Any]],
703
+ max_rows: int,
704
+ timeout_seconds: int,
705
+ output_format: OutputFormat,
706
+ structured_output_schema: Optional[Dict[str, Any]]
707
+ ) -> ToolResult:
708
+ """Execute SQL query and format results according to specifications."""
709
+ try:
710
+ db_connection = await self._get_database_connection(database_flavor, connection_params)
711
+
712
+ # Execute query with timeout and row limit
713
+ start_time = datetime.utcnow()
714
+
715
+ # This integrates your existing DatabaseQueryTool logic
716
+ raw_results = await self._execute_query_with_asyncdb(
717
+ db_connection, sql_query, max_rows, timeout_seconds
718
+ )
719
+
720
+ execution_time = (datetime.utcnow() - start_time).total_seconds()
721
+
722
+ # Format results according to specified output format
723
+ formatted_results = await self._format_query_results(
724
+ raw_results, output_format, structured_output_schema
725
+ )
726
+
727
+ return ToolResult(
728
+ status="success",
729
+ result={
730
+ "data": formatted_results,
731
+ "row_count": len(raw_results) if isinstance(raw_results, list) else None,
732
+ "execution_time": execution_time,
733
+ "output_format": output_format.value,
734
+ "query": sql_query
735
+ },
736
+ metadata={
737
+ "operation": "query_execution",
738
+ "database_flavor": database_flavor.value,
739
+ "rows_returned": len(raw_results) if isinstance(raw_results, list) else 0
740
+ }
741
+ )
742
+
743
+ except Exception as e:
744
+ return ToolResult(
745
+ status="error",
746
+ result=None,
747
+ error=f"Query execution failed: {str(e)}",
748
+ metadata={"operation": "query_execution", "query": sql_query}
749
+ )
750
+
751
+ async def _explain_query_operation(
752
+ self,
753
+ sql_query: str,
754
+ database_flavor: DatabaseFlavor,
755
+ connection_params: Optional[Dict[str, Any]]
756
+ ) -> ToolResult:
757
+ """
758
+ Explain query execution plan and provide LLM-based optimizations.
759
+ """
760
+ if not sql_query:
761
+ return ToolResult(
762
+ status="error",
763
+ result=None,
764
+ error="No SQL query provided for explanation",
765
+ metadata={"operation": "explain_query"}
766
+ )
767
+
768
+ try:
769
+ db_connection = await self._get_database_connection(database_flavor, connection_params)
770
+
771
+ # Determine appropriate EXPLAIN command
772
+ explain_cmd = f"EXPLAIN ANALYZE {sql_query}"
773
+ if database_flavor == DatabaseFlavor.MYSQL:
774
+ # MySQL 8.0.18+ supports EXPLAIN ANALYZE, otherwise fallback to EXPLAIN
775
+ # For safety/compatibility we might start with EXPLAIN if ANALYZE fails or just try
776
+ explain_cmd = f"EXPLAIN ANALYZE {sql_query}"
777
+ elif database_flavor == DatabaseFlavor.BIGQUERY:
778
+ # BigQuery doesn't support EXPLAIN ANALYZE syntax directly in this way usually
779
+ # It returns stats in job metadata.
780
+ # But we can try to use Dry Run or similar.
781
+ # For now, let's assume standard SQL syntax applies or let execution fail and fallback
782
+ pass
783
+
784
+ # Execute explanation
785
+ try:
786
+ raw_plan = await self._execute_query_with_asyncdb(
787
+ db_connection, explain_cmd, max_rows=0, timeout_seconds=30
788
+ )
789
+ except Exception as e:
790
+ # Fallback to simple EXPLAIN if ANALYZE fails (e.g. not supported or timeouts)
791
+ self.logger.warning(f"EXPLAIN ANALYZE failed, falling back to EXPLAIN: {e}")
792
+ explain_cmd = f"EXPLAIN {sql_query}"
793
+ raw_plan = await self._execute_query_with_asyncdb(
794
+ db_connection, explain_cmd, max_rows=0, timeout_seconds=30
795
+ )
796
+
797
+ # Format plan into string
798
+ plan_text = ""
799
+ if isinstance(raw_plan, list):
800
+ # Flatten the list of rows/dicts
801
+ for row in raw_plan:
802
+ if isinstance(row, dict):
803
+ # Usually the first column contains the plan output
804
+ plan_text += list(row.values())[0] + "\n"
805
+ elif isinstance(row, (list, tuple)):
806
+ plan_text += str(row[0]) + "\n"
807
+ else:
808
+ plan_text += str(row) + "\n"
809
+ else:
810
+ plan_text = str(raw_plan)
811
+
812
+ # Ask LLM to explain and optimize
813
+ llm_explanation = "No LLM configured for explanation."
814
+ if self.llm:
815
+ prompt = (
816
+ f"You are a database performance expert. Analyze the following query plan for a {database_flavor.value} database.\n"
817
+ f"Query:\n```sql\n{sql_query}\n```\n\n"
818
+ f"Execution Plan:\n```\n{plan_text}\n```\n\n"
819
+ "Please provide:\n"
820
+ "1. A human-readable explanation of how the query is executed.\n"
821
+ "2. Performance bottlenecks identified in the plan.\n"
822
+ "3. Concrete suggestions for indexes or query rewrites to improve performance.\n"
823
+ "4. Rating of current query efficiency (1-10)."
824
+ )
825
+
826
+ response = await self.llm.ask(prompt)
827
+ if isinstance(response, AIMessage):
828
+ llm_explanation = str(response.output).strip()
829
+ elif isinstance(response, dict) and 'content' in response:
830
+ llm_explanation = str(response['content']).strip()
831
+ else:
832
+ llm_explanation = str(response).strip()
833
+
834
+ return ToolResult(
835
+ status="success",
836
+ result={
837
+ "query": sql_query,
838
+ "plan": plan_text,
839
+ "analysis": llm_explanation,
840
+ "database_flavor": database_flavor.value
841
+ },
842
+ metadata={
843
+ "operation": "explain_query",
844
+ "command_used": explain_cmd
845
+ }
846
+ )
847
+
848
+ except Exception as e:
849
+ return ToolResult(
850
+ status="error",
851
+ result=None,
852
+ error=f"Query explanation failed: {str(e)}",
853
+ metadata={"operation": "explain_query"}
854
+ )
855
+
856
+ def _build_schema_context_for_llm(
857
+ self,
858
+ schema_metadata: SchemaMetadata,
859
+ natural_language_query: str
860
+ ) -> Dict[str, Any]:
861
+ """
862
+ Build rich schema context for LLM query generation.
863
+
864
+ This is a critical method that determines query generation quality.
865
+ It intelligently selects relevant schema elements based on the natural language query.
866
+ """
867
+ # Use vector similarity or keyword matching to find relevant tables
868
+ relevant_tables = self._find_relevant_tables(schema_metadata, natural_language_query)
869
+
870
+ # Build comprehensive context including relationships, constraints, and sample data
871
+ context = {
872
+ "database_flavor": schema_metadata.database_flavor.value,
873
+ "schema_name": schema_metadata.schema_name,
874
+ "relevant_tables": relevant_tables,
875
+ "table_relationships": self._extract_table_relationships(schema_metadata, relevant_tables),
876
+ "common_patterns": self._get_query_patterns_for_tables(relevant_tables),
877
+ "data_types_guide": self._get_data_type_guide(schema_metadata.database_flavor)
878
+ }
879
+
880
+ return context
881
+
882
+ async def _execute_query_with_asyncdb(
883
+ self,
884
+ db_connection: AsyncDB,
885
+ sql_query: str,
886
+ max_rows: int,
887
+ timeout_seconds: int
888
+ ) -> Any:
889
+ """Execute query using AsyncDB with proper error handling and limits."""
890
+ # This integrates your existing DatabaseQueryTool execution logic
891
+ # but with enhanced error handling and result limiting
892
+
893
+ try:
894
+ # Add LIMIT clause if not present and max_rows is specified
895
+ if max_rows > 0 and "LIMIT" not in sql_query.upper():
896
+ sql_query = f"{sql_query.rstrip(';')} LIMIT {max_rows};"
897
+
898
+ # Execute with timeout using asyncio
899
+ async with await db_connection.connection() as conn:
900
+ return await asyncio.wait_for(
901
+ conn.fetchall(sql_query),
902
+ timeout=timeout_seconds
903
+ )
904
+
905
+ except asyncio.TimeoutError as e:
906
+ raise TimeoutError(
907
+ f"Query execution timed out after {timeout_seconds} seconds"
908
+ ) from e
909
+ except Exception as e:
910
+ raise RuntimeError(
911
+ f"Database execution error: {str(e)}"
912
+ ) from e
913
+
914
+ async def _format_query_results(
915
+ self,
916
+ raw_results: Any,
917
+ output_format: OutputFormat,
918
+ structured_output_schema: Optional[Dict[str, Any]]
919
+ ) -> Any:
920
+ """Format query results according to specified output format."""
921
+ if output_format == OutputFormat.PANDAS:
922
+ return pd.DataFrame(raw_results) if raw_results else pd.DataFrame()
923
+ elif output_format == OutputFormat.JSON:
924
+ return json.dumps(raw_results, default=str, indent=2)
925
+ elif output_format == OutputFormat.DICT:
926
+ return raw_results
927
+ elif output_format == OutputFormat.CSV:
928
+ df = pd.DataFrame(raw_results) if raw_results else pd.DataFrame()
929
+ return df.to_csv(index=False)
930
+ elif output_format == OutputFormat.STRUCTURED and structured_output_schema:
931
+ # Convert results to Pydantic models based on provided schema
932
+ return self._convert_to_structured_output(raw_results, structured_output_schema)
933
+ else:
934
+ return raw_results
935
+
936
+ # Database-specific implementations (these would replace your current separate tools)
937
+ async def _generate_postgresql_query(self, natural_language: str, schema_context: Dict) -> str:
938
+ """
939
+ Generate PostgreSQL-specific SQL from natural language.
940
+
941
+ This method would integrate your existing SQLAgent logic but with enhanced
942
+ schema context and PostgreSQL-specific optimizations.
943
+ """
944
+ # Build prompt with rich schema context
945
+ prompt = self._build_query_generation_prompt(
946
+ natural_language, schema_context, "postgresql"
947
+ )
948
+
949
+ # Use your existing LLM client to generate the query
950
+ # This would integrate with your AI-Parrot LLM clients
951
+ return await self._call_llm_for_query_generation(prompt)
952
+
953
+ async def _validate_postgresql_query(self, query: str) -> QueryValidationResult:
954
+ """
955
+ Validate PostgreSQL query for syntax, security, and performance.
956
+
957
+ This provides the validation layer that was missing from your current SQLAgent.
958
+ """
959
+ validation_result = QueryValidationResult(
960
+ is_valid=True,
961
+ query_type=None,
962
+ affected_tables=[],
963
+ estimated_cost=None,
964
+ warnings=[],
965
+ errors=[],
966
+ security_checks={}
967
+ )
968
+
969
+ try:
970
+ # Parse query to determine type and affected tables
971
+ query_upper = query.strip().upper()
972
+ if query_upper.startswith('SELECT'):
973
+ validation_result.query_type = QueryType.SELECT
974
+ elif query_upper.startswith('INSERT'):
975
+ validation_result.query_type = QueryType.INSERT
976
+ # ... other query types
977
+
978
+ # Security checks
979
+ validation_result.security_checks = {
980
+ "no_sql_injection_patterns": self._check_sql_injection_patterns(query),
981
+ "no_dangerous_operations": self._check_dangerous_operations(query),
982
+ "proper_quoting": self._check_proper_quoting(query)
983
+ }
984
+
985
+ # Syntax validation (could use sqlparse or connect to database for EXPLAIN)
986
+ syntax_valid = await self._validate_syntax_postgresql(query)
987
+ if not syntax_valid:
988
+ validation_result.is_valid = False
989
+ validation_result.errors.append("Invalid SQL syntax")
990
+
991
+ # Performance warnings
992
+ if "SELECT *" in query_upper:
993
+ validation_result.warnings.append("Consider specifying explicit columns instead of SELECT *")
994
+
995
+ return validation_result
996
+
997
+ except Exception as e:
998
+ validation_result.is_valid = False
999
+ validation_result.errors.append(f"Validation error: {str(e)}")
1000
+ return validation_result
1001
+
1002
+ def _check_dangerous_operations(self, query: str) -> bool:
1003
+ """
1004
+ Check if query contains dangerous operations that should be blocked.
1005
+
1006
+ Returns:
1007
+ True if query is SAFE (no dangerous operations)
1008
+ False if dangerous operations detected
1009
+ """
1010
+ query_upper = query.upper()
1011
+
1012
+ dangerous_patterns = [
1013
+ # DDL operations
1014
+ r'\bDROP\s+(TABLE|DATABASE|SCHEMA|INDEX|VIEW|PROCEDURE|FUNCTION)\b',
1015
+ r'\bTRUNCATE\s+TABLE\b',
1016
+ r'\bALTER\s+(TABLE|DATABASE|SCHEMA)\s+.*\s+DROP\b',
1017
+ # DML without WHERE
1018
+ r'\bDELETE\s+FROM\s+\w+\s*;?\s*$',
1019
+ # Admin commands
1020
+ r'\bGRANT\b',
1021
+ r'\bREVOKE\b',
1022
+ r'\bCREATE\s+USER\b',
1023
+ r'\bDROP\s+USER\b',
1024
+ r'\bALTER\s+USER\b',
1025
+ # Command execution (PostgreSQL)
1026
+ r'\bCOPY\s+.*\s+TO\s+PROGRAM\b',
1027
+ # SQL Server
1028
+ r'\bEXEC\s*\(',
1029
+ r'\bXP_CMDSHELL\b',
1030
+ # MySQL file operations
1031
+ r'\bLOAD_FILE\b',
1032
+ r'\bINTO\s+OUTFILE\b',
1033
+ r'\bINTO\s+DUMPFILE\b',
1034
+ ]
1035
+
1036
+ for pattern in dangerous_patterns:
1037
+ if re.search(pattern, query_upper, re.IGNORECASE | re.DOTALL):
1038
+ return False
1039
+
1040
+ # Check DELETE/UPDATE without WHERE
1041
+ if re.search(r'\bDELETE\s+FROM\s+\w+\s*$', query_upper):
1042
+ return False
1043
+
1044
+ update_match = re.search(r'\bUPDATE\s+\w+\s+SET\s+', query_upper)
1045
+ if update_match and 'WHERE' not in query_upper:
1046
+ return False
1047
+
1048
+ return True
1049
+
1050
+ def _check_sql_injection_patterns(self, query: str) -> bool:
1051
+ """
1052
+ Check for common SQL injection patterns.
1053
+
1054
+ Returns:
1055
+ True if no injection patterns found (SAFE)
1056
+ False if potential injection detected
1057
+ """
1058
+ injection_patterns = [
1059
+ # Union/Boolean-based
1060
+ r"'\s*(OR|AND)\s+['\"0-9]",
1061
+ r"'\s*OR\s+1\s*=\s*1",
1062
+ r"'\s*OR\s+'[^']*'\s*=\s*'[^']*'",
1063
+ # Comment-based
1064
+ r";\s*--",
1065
+ r";\s*/\*",
1066
+ r"--\s*$",
1067
+ # Stacked queries
1068
+ r"'\s*;\s*(DROP|DELETE|UPDATE|INSERT|EXEC)\b",
1069
+ r";\s*(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER)\b",
1070
+ # UNION injection
1071
+ r"\bUNION\s+(ALL\s+)?SELECT\b.*\bFROM\b",
1072
+ # Time-based
1073
+ r"\bSLEEP\s*\(",
1074
+ r"\bWAITFOR\s+DELAY\b",
1075
+ r"\bBENCHMARK\s*\(",
1076
+ r"\bPG_SLEEP\s*\(",
1077
+ # Encoding attempts
1078
+ r"0x[0-9a-fA-F]+",
1079
+ r"\bCHAR\s*\(\s*\d+\s*\)",
1080
+ ]
1081
+
1082
+ for pattern in injection_patterns:
1083
+ if re.search(pattern, query, re.IGNORECASE):
1084
+ return False
1085
+
1086
+ return True
1087
+
1088
+ def _check_proper_quoting(self, query: str) -> bool:
1089
+ """
1090
+ Check if string literals are properly quoted.
1091
+
1092
+ Returns:
1093
+ True if quoting appears proper (SAFE)
1094
+ False if improper quoting detected
1095
+ """
1096
+ # Check for unbalanced quotes
1097
+ single_quotes = query.count("'") - query.count("\\'") - query.count("''")
1098
+ double_quotes = query.count('"') - query.count('\\"') - query.count('""')
1099
+
1100
+ if single_quotes % 2 != 0 or double_quotes % 2 != 0:
1101
+ return False
1102
+
1103
+ # Check for dangerous patterns after string literals
1104
+ dangerous_patterns = [
1105
+ r"'\s*\)\s*(OR|AND|UNION)\b",
1106
+ r"'\s*;\s*\w+",
1107
+ ]
1108
+
1109
+ for pattern in dangerous_patterns:
1110
+ if re.search(pattern, query, re.IGNORECASE):
1111
+ return False
1112
+
1113
+ return True
1114
+
1115
+ async def _validate_syntax_postgresql(self, query: str) -> bool:
1116
+ """Validate PostgreSQL query syntax using pattern matching."""
1117
+ try:
1118
+ query_stripped = query.strip().rstrip(';')
1119
+ query_upper = query_stripped.upper()
1120
+
1121
+ valid_starts = [
1122
+ 'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'WITH',
1123
+ 'CREATE', 'ALTER', 'DROP', 'TRUNCATE',
1124
+ 'BEGIN', 'COMMIT', 'ROLLBACK', 'SAVEPOINT',
1125
+ 'EXPLAIN', 'ANALYZE', 'VACUUM', 'REINDEX',
1126
+ 'GRANT', 'REVOKE', 'SET', 'SHOW', 'RESET',
1127
+ 'COPY', 'CALL', 'DO', 'LOCK',
1128
+ ]
1129
+
1130
+ first_word = query_upper.split()[0] if query_stripped else ''
1131
+ if first_word not in valid_starts:
1132
+ return False
1133
+
1134
+ # Validate statement structures
1135
+ if query_upper.startswith('INSERT') and 'INTO' not in query_upper:
1136
+ return False
1137
+ if query_upper.startswith('UPDATE') and 'SET' not in query_upper:
1138
+ return False
1139
+ if query_upper.startswith('DELETE') and 'FROM' not in query_upper:
1140
+ return False
1141
+
1142
+ # Check balance
1143
+ if query.count('(') != query.count(')'):
1144
+ return False
1145
+ if query.count('[') != query.count(']'):
1146
+ return False
1147
+
1148
+ return True
1149
+ except Exception:
1150
+ return False
1151
+
1152
+ # =========================================================================
1153
+ # MYSQL METHODS
1154
+ # =========================================================================
1155
+
1156
+ async def _generate_mysql_query(
1157
+ self,
1158
+ natural_language: str,
1159
+ schema_context: Dict[str, Any]
1160
+ ) -> str:
1161
+ """Generate MySQL-specific SQL from natural language."""
1162
+ prompt = self._build_query_generation_prompt(
1163
+ natural_language, schema_context, "mysql"
1164
+ )
1165
+
1166
+ mysql_instructions = """
1167
+ MySQL-Specific Rules:
1168
+ 1. Use backticks (`) for identifier quoting
1169
+ 2. Use LIMIT for row limiting
1170
+ 3. Use IFNULL() instead of COALESCE() for two arguments
1171
+ 4. Use NOW() for current timestamp
1172
+ 5. Use DATE_FORMAT() for date formatting
1173
+ 6. Use CONCAT() for string concatenation
1174
+ 7. Boolean values are 1/0
1175
+ 8. Use REGEXP for regex matching
1176
+ """
1177
+
1178
+ generated_query = await self._call_llm_for_query_generation(
1179
+ f"{prompt}\n\n{mysql_instructions}"
1180
+ )
1181
+
1182
+ return self._ensure_mysql_compatibility(generated_query)
1183
+
1184
+ def _ensure_mysql_compatibility(self, query: str) -> str:
1185
+ """Post-process query to ensure MySQL compatibility."""
1186
+ result = query
1187
+
1188
+ # Replace double quotes with backticks for identifiers
1189
+ result = re.sub(r'"(\w+)"(?=\s*[,.\)\s]|$)', r'`\1`', result)
1190
+
1191
+ # Replace COALESCE with two args to IFNULL
1192
+ result = re.sub(
1193
+ r'\bCOALESCE\s*\(\s*([^,]+)\s*,\s*([^,\)]+)\s*\)',
1194
+ r'IFNULL(\1, \2)',
1195
+ result,
1196
+ flags=re.IGNORECASE
1197
+ )
1198
+
1199
+ # Replace TRUE/FALSE with 1/0
1200
+ result = re.sub(r'\bTRUE\b', '1', result, flags=re.IGNORECASE)
1201
+ result = re.sub(r'\bFALSE\b', '0', result, flags=re.IGNORECASE)
1202
+
1203
+ return result
1204
+
1205
    async def _validate_mysql_query(self, query: str) -> QueryValidationResult:
        """
        Validate a MySQL query for syntax, security, and performance.

        Pipeline: classify the statement, collect affected tables, run the
        three security predicates (injection patterns, dangerous operations,
        quoting), run the heuristic syntax check, then attach non-fatal
        performance warnings.

        Args:
            query: SQL text to validate.

        Returns:
            QueryValidationResult; ``is_valid`` is False when any security
            or syntax check fails.  Warnings never affect validity.
        """
        # Start optimistic; the checks below flip is_valid on failure.
        validation_result = QueryValidationResult(
            is_valid=True,
            query_type=None,
            affected_tables=[],
            estimated_cost=None,
            warnings=[],
            errors=[],
            security_checks={}
        )

        try:
            query_upper = query.strip().upper()

            # Determine query type from the leading keyword.
            for qt in QueryType:
                if query_upper.startswith(qt.value):
                    validation_result.query_type = qt
                    break

            # Extract tables referenced by the statement.
            validation_result.affected_tables = self._extract_tables_from_query(query)

            # Security checks: each predicate returns True when SAFE.
            validation_result.security_checks = {
                "no_sql_injection_patterns": self._check_sql_injection_patterns(query),
                "no_dangerous_operations": self._check_dangerous_operations(query),
                "proper_quoting": self._check_proper_quoting(query)
            }

            if not all(validation_result.security_checks.values()):
                validation_result.is_valid = False
                for check, passed in validation_result.security_checks.items():
                    if not passed:
                        validation_result.errors.append(f"Security check failed: {check}")

            # Syntax validation (heuristic, MySQL-specific).
            if not await self._validate_syntax_mysql(query):
                validation_result.is_valid = False
                validation_result.errors.append("Invalid MySQL syntax")

            # Performance warnings (advisory only; do not affect validity).
            if "SELECT *" in query_upper:
                validation_result.warnings.append("Consider specifying explicit columns")

            if re.search(r"LIKE\s*'%", query, re.IGNORECASE):
                validation_result.warnings.append("Leading wildcard may prevent index usage")

            return validation_result

        except Exception as e:
            # Unexpected failures mark the query invalid rather than raising.
            validation_result.is_valid = False
            validation_result.errors.append(f"Validation error: {str(e)}")
            return validation_result
1260
+
1261
+ async def _validate_syntax_mysql(self, query: str) -> bool:
1262
+ """Validate MySQL-specific query syntax."""
1263
+ try:
1264
+ query_stripped = query.strip().rstrip(';')
1265
+ query_upper = query_stripped.upper()
1266
+
1267
+ valid_starts = [
1268
+ 'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'REPLACE',
1269
+ 'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'RENAME',
1270
+ 'START', 'BEGIN', 'COMMIT', 'ROLLBACK', 'SAVEPOINT',
1271
+ 'SET', 'SHOW', 'DESCRIBE', 'DESC', 'EXPLAIN',
1272
+ 'GRANT', 'REVOKE', 'LOCK', 'UNLOCK', 'USE', 'WITH',
1273
+ ]
1274
+
1275
+ first_word = query_upper.split()[0] if query_stripped else ''
1276
+ if first_word not in valid_starts:
1277
+ return False
1278
+
1279
+ # Validate statement structures
1280
+ if first_word in ('INSERT', 'REPLACE') and 'INTO' not in query_upper:
1281
+ return False
1282
+ if first_word == 'UPDATE' and 'SET' not in query_upper:
1283
+ return False
1284
+
1285
+ # Check balance
1286
+ if query.count('(') != query.count(')'):
1287
+ return False
1288
+ if query.count('`') % 2 != 0:
1289
+ return False
1290
+
1291
+ return True
1292
+ except Exception:
1293
+ return False
1294
+
1295
+ # =========================================================================
1296
+ # BIGQUERY METHODS
1297
+ # =========================================================================
1298
+
1299
+ async def _generate_bigquery_query(
1300
+ self,
1301
+ natural_language: str,
1302
+ schema_context: Dict[str, Any]
1303
+ ) -> str:
1304
+ """Generate BigQuery-specific SQL from natural language."""
1305
+ prompt = self._build_query_generation_prompt(
1306
+ natural_language, schema_context, "bigquery"
1307
+ )
1308
+
1309
+ bigquery_instructions = """
1310
+ BigQuery-Specific Rules:
1311
+ 1. Use backticks for table names: `project.dataset.table`
1312
+ 2. Use STRUCT<> and ARRAY<> for complex types
1313
+ 3. Use UNNEST() to flatten arrays
1314
+ 4. Use SAFE_DIVIDE() for division with potential zeros
1315
+ 5. Use FORMAT_DATE/FORMAT_TIMESTAMP for date formatting
1316
+ 6. Use DATE_DIFF, TIMESTAMP_DIFF for date differences
1317
+ 7. Use QUALIFY clause for window function filtering
1318
+ 8. Use Standard SQL (prefix with #standardSQL if needed)
1319
+ """
1320
+
1321
+ generated_query = await self._call_llm_for_query_generation(
1322
+ f"{prompt}\n\n{bigquery_instructions}"
1323
+ )
1324
+
1325
+ return self._ensure_bigquery_compatibility(generated_query)
1326
+
1327
+ def _ensure_bigquery_compatibility(self, query: str) -> str:
1328
+ """Post-process query to ensure BigQuery compatibility."""
1329
+ result = query
1330
+
1331
+ # Quote project.dataset.table names
1332
+ result = re.sub(
1333
+ r'(?<![`\w])(\w+)\.(\w+)\.(\w+)(?![`\w])',
1334
+ r'`\1.\2.\3`',
1335
+ result
1336
+ )
1337
+
1338
+ # Replace NOW() with CURRENT_TIMESTAMP()
1339
+ result = re.sub(r'\bNOW\s*\(\s*\)', 'CURRENT_TIMESTAMP()', result, flags=re.IGNORECASE)
1340
+
1341
+ # Replace GETDATE()
1342
+ result = re.sub(r'\bGETDATE\s*\(\s*\)', 'CURRENT_TIMESTAMP()', result, flags=re.IGNORECASE)
1343
+
1344
+ return result
1345
+
1346
    async def _validate_bigquery_query(self, query: str) -> QueryValidationResult:
        """
        Validate a BigQuery query for syntax, security, and performance.

        Pipeline: strip an optional dialect prefix, classify the statement,
        collect referenced tables, run the security predicates, run the
        heuristic syntax check, then attach cost-oriented warnings.

        Args:
            query: SQL text to validate.

        Returns:
            QueryValidationResult; ``is_valid`` is False when any security
            or syntax check fails.  Warnings are advisory only.
        """
        # Start optimistic; the checks below flip is_valid on failure.
        validation_result = QueryValidationResult(
            is_valid=True,
            query_type=None,
            affected_tables=[],
            estimated_cost=None,
            warnings=[],
            errors=[],
            security_checks={}
        )

        try:
            query_upper = query.strip().upper()

            # Remove SQL dialect prefix (e.g. "#standardSQL") before classifying.
            if query_upper.startswith('#'):
                newline_idx = query.find('\n')
                if newline_idx > 0:
                    query_upper = query[newline_idx:].strip().upper()

            # Determine query type; MERGE is treated as an UPDATE-like write.
            if query_upper.startswith(('SELECT', 'WITH')):
                validation_result.query_type = QueryType.SELECT
            elif query_upper.startswith('MERGE'):
                validation_result.query_type = QueryType.UPDATE
            else:
                for qt in QueryType:
                    if query_upper.startswith(qt.value):
                        validation_result.query_type = qt
                        break

            # Extract tables (handles backtick-quoted project.dataset.table).
            validation_result.affected_tables = self._extract_bigquery_tables(query)

            # Security checks: each predicate returns True when SAFE.
            validation_result.security_checks = {
                "no_sql_injection_patterns": self._check_sql_injection_patterns(query),
                "no_dangerous_operations": self._check_dangerous_bigquery_operations(query),
                "proper_quoting": self._check_proper_quoting(query)
            }

            if not all(validation_result.security_checks.values()):
                validation_result.is_valid = False
                for check, passed in validation_result.security_checks.items():
                    if not passed:
                        validation_result.errors.append(f"Security check failed: {check}")

            # Syntax validation (heuristic, BigQuery-specific).
            if not await self._validate_syntax_bigquery(query):
                validation_result.is_valid = False
                validation_result.errors.append("Invalid BigQuery SQL syntax")

            # Performance warnings: BigQuery bills by bytes scanned.
            if "SELECT *" in query_upper:
                validation_result.warnings.append(
                    "SELECT * scans all columns - specify needed columns for cost reduction"
                )

            if 'WHERE' not in query_upper and '_PARTITIONTIME' not in query_upper:
                validation_result.warnings.append(
                    "Consider adding partition filter for cost reduction"
                )

            if query.strip().startswith('#legacySQL'):
                validation_result.warnings.append("Consider migrating to Standard SQL")

            return validation_result

        except Exception as e:
            # Unexpected failures mark the query invalid rather than raising.
            validation_result.is_valid = False
            validation_result.errors.append(f"Validation error: {str(e)}")
            return validation_result
1419
+
1420
+ def _check_dangerous_bigquery_operations(self, query: str) -> bool:
1421
+ """Check for dangerous BigQuery operations. Returns True if SAFE."""
1422
+ query_upper = query.upper()
1423
+
1424
+ dangerous_patterns = [
1425
+ r'\bDROP\s+(TABLE|SCHEMA|VIEW|FUNCTION)\b',
1426
+ r'\bTRUNCATE\s+TABLE\b',
1427
+ r'\bDELETE\s+FROM\s+`[^`]+`\s*$',
1428
+ r'\bDROP\s+ALL\s+ROW\s+ACCESS\s+POLICIES\b',
1429
+ ]
1430
+
1431
+ for pattern in dangerous_patterns:
1432
+ if re.search(pattern, query_upper, re.IGNORECASE | re.DOTALL):
1433
+ return False
1434
+
1435
+ return True
1436
+
1437
+ def _extract_bigquery_tables(self, query: str) -> List[str]:
1438
+ """Extract table names from BigQuery query."""
1439
+ tables = set()
1440
+
1441
+ # Backtick-quoted fully-qualified names
1442
+ tables.update(re.findall(r'`([a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+)`', query))
1443
+ tables.update(re.findall(r'`([a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+)`', query))
1444
+
1445
+ # Standard table references
1446
+ tables.update(self._extract_tables_from_query(query))
1447
+
1448
+ return list(tables)
1449
+
1450
+ async def _validate_syntax_bigquery(self, query: str) -> bool:
1451
+ """Validate BigQuery-specific query syntax."""
1452
+ try:
1453
+ query_stripped = query.strip()
1454
+
1455
+ if query_stripped.startswith('#'):
1456
+ newline_idx = query_stripped.find('\n')
1457
+ if newline_idx > 0:
1458
+ query_stripped = query_stripped[newline_idx:].strip()
1459
+
1460
+ query_stripped = query_stripped.rstrip(';')
1461
+ query_upper = query_stripped.upper()
1462
+
1463
+ valid_starts = [
1464
+ 'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE',
1465
+ 'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'WITH',
1466
+ 'DECLARE', 'SET', 'EXECUTE', 'BEGIN', 'IF',
1467
+ 'EXPORT', 'LOAD', 'GRANT', 'REVOKE', 'ASSERT',
1468
+ ]
1469
+
1470
+ first_word = query_upper.split()[0] if query_stripped else ''
1471
+ if first_word not in valid_starts:
1472
+ return False
1473
+
1474
+ # Check balance
1475
+ if query.count('(') != query.count(')'):
1476
+ return False
1477
+ if query.count('`') % 2 != 0:
1478
+ return False
1479
+ if query.count('[') != query.count(']'):
1480
+ return False
1481
+ if query.count('<') != query.count('>'):
1482
+ return False
1483
+
1484
+ return True
1485
+ except Exception:
1486
+ return False
1487
+
1488
+ # =========================================================================
1489
+ # HELPER METHODS
1490
+ # =========================================================================
1491
+
1492
+ def _extract_tables_from_query(self, query: str) -> List[str]:
1493
+ """Extract table names from SQL query."""
1494
+ tables = set()
1495
+
1496
+ patterns = [
1497
+ r'\bFROM\s+([`"\[]?[\w.-]+[`"\]]?)',
1498
+ r'\bJOIN\s+([`"\[]?[\w.-]+[`"\]]?)',
1499
+ r'\bINSERT\s+INTO\s+([`"\[]?[\w.-]+[`"\]]?)',
1500
+ r'\bUPDATE\s+([`"\[]?[\w.-]+[`"\]]?)',
1501
+ r'\bDELETE\s+FROM\s+([`"\[]?[\w.-]+[`"\]]?)',
1502
+ ]
1503
+
1504
+ for pattern in patterns:
1505
+ for match in re.findall(pattern, query, re.IGNORECASE):
1506
+ tables.add(match.strip('`"[]'))
1507
+
1508
+ return list(tables)
1509
+
1510
    def _find_relevant_tables(
        self,
        schema_metadata: SchemaMetadata,
        natural_language_query: str
    ) -> List[Dict[str, Any]]:
        """
        Score schema tables by keyword overlap with the user's request.

        Scoring (cumulative per table): +10 when a whole word of the table
        name matches a query keyword, +5 per keyword contained in the table
        name, +3 per column with a whole-word match, +1 per column with a
        substring match not already counted.

        Args:
            schema_metadata: Introspected schema; only ``tables`` and
                ``schema_name`` are read here.
            natural_language_query: The user's request in plain language.

        Returns:
            Up to 10 table descriptors sorted by descending
            ``relevance_score``.
        """
        relevant_tables = []
        query_lower = natural_language_query.lower()

        # Tokenize the request into candidate keywords.
        keywords = set(re.findall(r'\b\w+\b', query_lower))

        # Common English / query filler words carry no signal.
        stop_words = {
            'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been',
            'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would',
            'could', 'should', 'may', 'might', 'must', 'shall', 'can',
            'to', 'of', 'in', 'for', 'on', 'with', 'at', 'by', 'from',
            'and', 'or', 'but', 'if', 'as', 'show', 'get', 'find', 'list',
            'give', 'tell', 'me', 'i', 'you', 'we', 'they', 'select', 'all',
        }

        keywords = keywords - stop_words

        for table in schema_metadata.tables:
            table_name = table.get('table_name', '').lower()
            columns = table.get('columns', [])

            score = 0
            matched_columns = []

            # Whole-word match against the table name is the strongest signal.
            table_words = set(re.findall(r'\w+', table_name))
            if table_words & keywords:
                score += 10

            # Substring matches in the table name add a smaller boost.
            for keyword in keywords:
                if keyword in table_name:
                    score += 5

            # Column-level matches: whole-word first, then substring matches
            # (deduplicated against columns already matched).
            for column in columns:
                col_name = column.get('column_name', '').lower()
                col_words = set(re.findall(r'\w+', col_name))

                if col_words & keywords:
                    score += 3
                    matched_columns.append(col_name)

                for keyword in keywords:
                    if keyword in col_name and col_name not in matched_columns:
                        score += 1
                        matched_columns.append(col_name)

            if score > 0:
                relevant_tables.append({
                    'table_name': table.get('table_name'),
                    'schema': table.get('schema', schema_metadata.schema_name),
                    'columns': columns,
                    'matched_columns': matched_columns,
                    'relevance_score': score,
                    'comment': table.get('comment', '')
                })

        # Highest-scoring tables first; cap at 10 to bound prompt size.
        relevant_tables.sort(key=lambda x: x['relevance_score'], reverse=True)
        return relevant_tables[:10]
1574
+
1575
+ def _extract_table_relationships(
1576
+ self,
1577
+ schema_metadata: SchemaMetadata,
1578
+ relevant_tables: List[Dict[str, Any]]
1579
+ ) -> List[Dict[str, Any]]:
1580
+ """Extract relationships between relevant tables."""
1581
+ relationships = []
1582
+ relevant_table_names = {t['table_name'] for t in relevant_tables}
1583
+
1584
+ # From constraints
1585
+ for constraint in schema_metadata.constraints:
1586
+ if constraint.get('constraint_type') == 'FOREIGN KEY':
1587
+ source_table = constraint.get('table_name')
1588
+ target_table = constraint.get('referenced_table')
1589
+
1590
+ if source_table in relevant_table_names or target_table in relevant_table_names:
1591
+ relationships.append({
1592
+ 'type': 'foreign_key',
1593
+ 'source_table': source_table,
1594
+ 'source_column': constraint.get('column_name'),
1595
+ 'target_table': target_table,
1596
+ 'target_column': constraint.get('referenced_column'),
1597
+ })
1598
+
1599
+ # Infer from naming conventions
1600
+ for table in relevant_tables:
1601
+ for column in table.get('columns', []):
1602
+ col_name = column.get('column_name', '')
1603
+
1604
+ if col_name.endswith('_id'):
1605
+ potential_table = col_name[:-3]
1606
+ for pt in [potential_table, potential_table + 's', potential_table + 'es']:
1607
+ if pt in relevant_table_names:
1608
+ relationships.append({
1609
+ 'type': 'inferred',
1610
+ 'source_table': table['table_name'],
1611
+ 'source_column': col_name,
1612
+ 'target_table': pt,
1613
+ 'target_column': 'id',
1614
+ })
1615
+ break
1616
+
1617
+ return relationships
1618
+
1619
+ def _get_query_patterns_for_tables(
1620
+ self,
1621
+ relevant_tables: List[Dict[str, Any]]
1622
+ ) -> List[Dict[str, Any]]:
1623
+ """Generate common query patterns for relevant tables."""
1624
+ patterns = []
1625
+
1626
+ for table in relevant_tables[:3]:
1627
+ table_name = table['table_name']
1628
+ columns = table.get('columns', [])
1629
+
1630
+ if columns:
1631
+ col_list = ', '.join([c['column_name'] for c in columns[:5]])
1632
+ patterns.append({
1633
+ 'description': f'Select from {table_name}',
1634
+ 'pattern': f'SELECT {col_list} FROM {table_name} WHERE ...',
1635
+ })
1636
+
1637
+ if numeric_cols := [
1638
+ c for c in columns if c.get('data_type', '').lower() in (
1639
+ 'integer', 'int', 'bigint', 'numeric', 'decimal', 'float'
1640
+ )
1641
+ ]:
1642
+ patterns.append({
1643
+ 'description': f'Aggregate {table_name}',
1644
+ 'pattern': f'SELECT COUNT(*), SUM({numeric_cols[0]["column_name"]}) FROM {table_name} GROUP BY ...',
1645
+ })
1646
+
1647
+ return patterns
1648
+
1649
+ def _get_data_type_guide(self, database_flavor: DatabaseFlavor) -> Dict[str, Any]:
1650
+ """Get data type information for database flavor."""
1651
+ guides = {
1652
+ DatabaseFlavor.POSTGRESQL: {
1653
+ 'string_concat': '|| operator or CONCAT()',
1654
+ 'null_handling': 'IS NULL / IS NOT NULL, COALESCE()',
1655
+ 'boolean_type': 'BOOLEAN',
1656
+ },
1657
+ DatabaseFlavor.MYSQL: {
1658
+ 'string_concat': 'CONCAT() function',
1659
+ 'null_handling': 'IS NULL / IS NOT NULL, IFNULL()',
1660
+ 'boolean_type': 'TINYINT(1)',
1661
+ },
1662
+ DatabaseFlavor.BIGQUERY: {
1663
+ 'string_concat': 'CONCAT() or ||',
1664
+ 'null_handling': 'IS NULL / IS NOT NULL, IFNULL(), COALESCE()',
1665
+ 'boolean_type': 'BOOL',
1666
+ }
1667
+ }
1668
+ return guides.get(database_flavor, guides[DatabaseFlavor.POSTGRESQL])
1669
+
1670
+ def _build_query_generation_prompt(
1671
+ self,
1672
+ natural_language: str,
1673
+ schema_context: Dict[str, Any],
1674
+ dialect: str
1675
+ ) -> str:
1676
+ """Build prompt for LLM query generation."""
1677
+ prompt_parts = [
1678
+ f"Generate a {dialect.upper()} SQL query for:",
1679
+ f"\nRequest: {natural_language}",
1680
+ f"\nDatabase: {schema_context.get('database_flavor', dialect).upper()}",
1681
+ "\n\nAvailable Tables:",
1682
+ ]
1683
+
1684
+ for table in schema_context.get('relevant_tables', [])[:5]:
1685
+ prompt_parts.append(f"\n\nTable: {table.get('table_name')}")
1686
+ columns = table.get('columns', [])[:15]
1687
+ if columns:
1688
+ prompt_parts.append("\nColumns:")
1689
+ for col in columns:
1690
+ col_info = f" - {col.get('column_name')}: {col.get('data_type', 'unknown')}"
1691
+ prompt_parts.append(col_info)
1692
+
1693
+ relationships = schema_context.get('table_relationships', [])
1694
+ if relationships:
1695
+ prompt_parts.append("\n\nRelationships:")
1696
+ for rel in relationships[:5]:
1697
+ prompt_parts.append(
1698
+ f" - {rel['source_table']}.{rel['source_column']} -> "
1699
+ f"{rel['target_table']}.{rel['target_column']}"
1700
+ )
1701
+
1702
+ prompt_parts.append("\n\nGenerate only the SQL query, no explanations:")
1703
+ return '\n'.join(prompt_parts)
1704
+
1705
+ async def _call_llm_for_query_generation(self, prompt: str) -> str:
1706
+ """Call LLM client to generate SQL query."""
1707
+ system_msg = "You are a SQL expert. Generate precise SQL queries. Return only the SQL, no explanations."
1708
+
1709
+ if self.llm:
1710
+ response = await self.llm.ask(prompt, system_prompt=system_msg)
1711
+ if isinstance(response, AIMessage):
1712
+ return str(response.output).strip()
1713
+ # Handle possible dict response if client doesn't return AIMessage (fallback)
1714
+ elif isinstance(response, dict) and 'content' in response:
1715
+ # Should extract text
1716
+ return str(response['content']).strip() # Simplified fallback
1717
+ return str(response).strip()
1718
+
1719
+ if hasattr(self, 'agent') and self.agent:
1720
+ response = await self.agent.llm.acomplete(prompt, system_message=system_msg)
1721
+ return response.strip()
1722
+
1723
+ if hasattr(self, 'llm_client') and self.llm_client:
1724
+ response = await self.llm_client.acomplete(prompt)
1725
+ return response.strip()
1726
+
1727
+ raise ValueError(
1728
+ "No LLM client configured. Provide an 'llm', 'agent' or 'llm_client' to DatabaseTool."
1729
+ )
1730
+
1731
    async def _update_schema_knowledge_base(
        self,
        schema_metadata: SchemaMetadata
    ) -> None:
        """
        Push one schema document per table into the knowledge store (RAG).

        No-op when no knowledge store is configured.

        Args:
            schema_metadata: Introspected schema whose tables are indexed.
        """
        # Nothing to do without a configured store.
        if not self.knowledge_store:
            return

        documents = []

        for table in schema_metadata.tables:
            table_name = table.get('table_name')
            columns = table.get('columns', [])

            # Human-readable "name (type)" pairs for the embedding text.
            column_descriptions = [
                f"{col.get('column_name')} ({col.get('data_type', 'unknown')})"
                for col in columns
            ]

            doc_text = f"""
            Table: {schema_metadata.schema_name}.{table_name}
            Database: {schema_metadata.database_flavor.value}
            Columns: {', '.join(column_descriptions)}
            """

            # Metadata lets retrieval filter by schema/table/flavor later.
            documents.append({
                'content': doc_text,
                'metadata': {
                    'type': 'database_schema',
                    'schema': schema_metadata.schema_name,
                    'table': table_name,
                    'database_flavor': schema_metadata.database_flavor.value,
                }
            })

        await self.knowledge_store.add_documents(documents)
1767
+
1768
+ def _convert_to_structured_output(
1769
+ self,
1770
+ raw_results: Any,
1771
+ schema: Dict[str, Any]
1772
+ ) -> List[Dict[str, Any]]:
1773
+ """Convert raw query results to structured output."""
1774
+ if not raw_results:
1775
+ return []
1776
+
1777
+ field_mappings = schema.get('field_mappings', {})
1778
+ structured_results = []
1779
+
1780
+ for row in raw_results:
1781
+ structured_row = {}
1782
+
1783
+ if isinstance(row, dict):
1784
+ if field_mappings:
1785
+ for target, source in field_mappings.items():
1786
+ if source in row:
1787
+ structured_row[target] = row[source]
1788
+ elif target in row:
1789
+ structured_row[target] = row[target]
1790
+ else:
1791
+ structured_row = row
1792
+ else:
1793
+ fields = schema.get('fields', [])
1794
+ for i, field in enumerate(fields):
1795
+ if i < len(row):
1796
+ structured_row[field] = row[i]
1797
+
1798
+ structured_results.append(structured_row)
1799
+
1800
+ return structured_results