ai-parrot 0.17.2__cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (535) hide show
  1. agentui/.prettierrc +15 -0
  2. agentui/QUICKSTART.md +272 -0
  3. agentui/README.md +59 -0
  4. agentui/env.example +16 -0
  5. agentui/jsconfig.json +14 -0
  6. agentui/package-lock.json +4242 -0
  7. agentui/package.json +34 -0
  8. agentui/scripts/postinstall/apply-patches.mjs +260 -0
  9. agentui/src/app.css +61 -0
  10. agentui/src/app.d.ts +13 -0
  11. agentui/src/app.html +12 -0
  12. agentui/src/components/LoadingSpinner.svelte +64 -0
  13. agentui/src/components/ThemeSwitcher.svelte +159 -0
  14. agentui/src/components/index.js +4 -0
  15. agentui/src/lib/api/bots.ts +60 -0
  16. agentui/src/lib/api/chat.ts +22 -0
  17. agentui/src/lib/api/http.ts +25 -0
  18. agentui/src/lib/components/BotCard.svelte +33 -0
  19. agentui/src/lib/components/ChatBubble.svelte +63 -0
  20. agentui/src/lib/components/Toast.svelte +21 -0
  21. agentui/src/lib/config.ts +20 -0
  22. agentui/src/lib/stores/auth.svelte.ts +73 -0
  23. agentui/src/lib/stores/theme.svelte.js +64 -0
  24. agentui/src/lib/stores/toast.svelte.ts +31 -0
  25. agentui/src/lib/utils/conversation.ts +39 -0
  26. agentui/src/routes/+layout.svelte +20 -0
  27. agentui/src/routes/+page.svelte +232 -0
  28. agentui/src/routes/login/+page.svelte +200 -0
  29. agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
  30. agentui/src/routes/talk/[agentId]/+page.ts +7 -0
  31. agentui/static/README.md +1 -0
  32. agentui/svelte.config.js +11 -0
  33. agentui/tailwind.config.ts +53 -0
  34. agentui/tsconfig.json +3 -0
  35. agentui/vite.config.ts +10 -0
  36. ai_parrot-0.17.2.dist-info/METADATA +472 -0
  37. ai_parrot-0.17.2.dist-info/RECORD +535 -0
  38. ai_parrot-0.17.2.dist-info/WHEEL +6 -0
  39. ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
  40. ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
  41. ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
  42. crew-builder/.prettierrc +15 -0
  43. crew-builder/QUICKSTART.md +259 -0
  44. crew-builder/README.md +113 -0
  45. crew-builder/env.example +17 -0
  46. crew-builder/jsconfig.json +14 -0
  47. crew-builder/package-lock.json +4182 -0
  48. crew-builder/package.json +37 -0
  49. crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
  50. crew-builder/src/app.css +62 -0
  51. crew-builder/src/app.d.ts +13 -0
  52. crew-builder/src/app.html +12 -0
  53. crew-builder/src/components/LoadingSpinner.svelte +64 -0
  54. crew-builder/src/components/ThemeSwitcher.svelte +149 -0
  55. crew-builder/src/components/index.js +9 -0
  56. crew-builder/src/lib/api/bots.ts +60 -0
  57. crew-builder/src/lib/api/chat.ts +80 -0
  58. crew-builder/src/lib/api/client.ts +56 -0
  59. crew-builder/src/lib/api/crew/crew.ts +136 -0
  60. crew-builder/src/lib/api/index.ts +5 -0
  61. crew-builder/src/lib/api/o365/auth.ts +65 -0
  62. crew-builder/src/lib/auth/auth.ts +54 -0
  63. crew-builder/src/lib/components/AgentNode.svelte +43 -0
  64. crew-builder/src/lib/components/BotCard.svelte +33 -0
  65. crew-builder/src/lib/components/ChatBubble.svelte +67 -0
  66. crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
  67. crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
  68. crew-builder/src/lib/components/JsonViewer.svelte +24 -0
  69. crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
  70. crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
  71. crew-builder/src/lib/components/Toast.svelte +67 -0
  72. crew-builder/src/lib/components/Toolbar.svelte +157 -0
  73. crew-builder/src/lib/components/index.ts +10 -0
  74. crew-builder/src/lib/config.ts +8 -0
  75. crew-builder/src/lib/stores/auth.svelte.ts +228 -0
  76. crew-builder/src/lib/stores/crewStore.ts +369 -0
  77. crew-builder/src/lib/stores/theme.svelte.js +145 -0
  78. crew-builder/src/lib/stores/toast.svelte.ts +69 -0
  79. crew-builder/src/lib/utils/conversation.ts +39 -0
  80. crew-builder/src/lib/utils/markdown.ts +122 -0
  81. crew-builder/src/lib/utils/talkHistory.ts +47 -0
  82. crew-builder/src/routes/+layout.svelte +20 -0
  83. crew-builder/src/routes/+page.svelte +539 -0
  84. crew-builder/src/routes/agents/+page.svelte +247 -0
  85. crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
  86. crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
  87. crew-builder/src/routes/builder/+page.svelte +204 -0
  88. crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
  89. crew-builder/src/routes/crew/ask/+page.ts +1 -0
  90. crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
  91. crew-builder/src/routes/login/+page.svelte +197 -0
  92. crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
  93. crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
  94. crew-builder/static/README.md +1 -0
  95. crew-builder/svelte.config.js +11 -0
  96. crew-builder/tailwind.config.ts +53 -0
  97. crew-builder/tsconfig.json +3 -0
  98. crew-builder/vite.config.ts +10 -0
  99. mcp_servers/calculator_server.py +309 -0
  100. parrot/__init__.py +27 -0
  101. parrot/__pycache__/__init__.cpython-310.pyc +0 -0
  102. parrot/__pycache__/version.cpython-310.pyc +0 -0
  103. parrot/_version.py +34 -0
  104. parrot/a2a/__init__.py +48 -0
  105. parrot/a2a/client.py +658 -0
  106. parrot/a2a/discovery.py +89 -0
  107. parrot/a2a/mixin.py +257 -0
  108. parrot/a2a/models.py +376 -0
  109. parrot/a2a/server.py +770 -0
  110. parrot/agents/__init__.py +29 -0
  111. parrot/bots/__init__.py +12 -0
  112. parrot/bots/a2a_agent.py +19 -0
  113. parrot/bots/abstract.py +3139 -0
  114. parrot/bots/agent.py +1129 -0
  115. parrot/bots/basic.py +9 -0
  116. parrot/bots/chatbot.py +669 -0
  117. parrot/bots/data.py +1618 -0
  118. parrot/bots/database/__init__.py +5 -0
  119. parrot/bots/database/abstract.py +3071 -0
  120. parrot/bots/database/cache.py +286 -0
  121. parrot/bots/database/models.py +468 -0
  122. parrot/bots/database/prompts.py +154 -0
  123. parrot/bots/database/retries.py +98 -0
  124. parrot/bots/database/router.py +269 -0
  125. parrot/bots/database/sql.py +41 -0
  126. parrot/bots/db/__init__.py +6 -0
  127. parrot/bots/db/abstract.py +556 -0
  128. parrot/bots/db/bigquery.py +602 -0
  129. parrot/bots/db/cache.py +85 -0
  130. parrot/bots/db/documentdb.py +668 -0
  131. parrot/bots/db/elastic.py +1014 -0
  132. parrot/bots/db/influx.py +898 -0
  133. parrot/bots/db/mock.py +96 -0
  134. parrot/bots/db/multi.py +783 -0
  135. parrot/bots/db/prompts.py +185 -0
  136. parrot/bots/db/sql.py +1255 -0
  137. parrot/bots/db/tools.py +212 -0
  138. parrot/bots/document.py +680 -0
  139. parrot/bots/hrbot.py +15 -0
  140. parrot/bots/kb.py +170 -0
  141. parrot/bots/mcp.py +36 -0
  142. parrot/bots/orchestration/README.md +463 -0
  143. parrot/bots/orchestration/__init__.py +1 -0
  144. parrot/bots/orchestration/agent.py +155 -0
  145. parrot/bots/orchestration/crew.py +3330 -0
  146. parrot/bots/orchestration/fsm.py +1179 -0
  147. parrot/bots/orchestration/hr.py +434 -0
  148. parrot/bots/orchestration/storage/__init__.py +4 -0
  149. parrot/bots/orchestration/storage/memory.py +100 -0
  150. parrot/bots/orchestration/storage/mixin.py +119 -0
  151. parrot/bots/orchestration/verify.py +202 -0
  152. parrot/bots/product.py +204 -0
  153. parrot/bots/prompts/__init__.py +96 -0
  154. parrot/bots/prompts/agents.py +155 -0
  155. parrot/bots/prompts/data.py +216 -0
  156. parrot/bots/prompts/output_generation.py +8 -0
  157. parrot/bots/scraper/__init__.py +3 -0
  158. parrot/bots/scraper/models.py +122 -0
  159. parrot/bots/scraper/scraper.py +1173 -0
  160. parrot/bots/scraper/templates.py +115 -0
  161. parrot/bots/stores/__init__.py +5 -0
  162. parrot/bots/stores/local.py +172 -0
  163. parrot/bots/webdev.py +81 -0
  164. parrot/cli.py +17 -0
  165. parrot/clients/__init__.py +16 -0
  166. parrot/clients/base.py +1491 -0
  167. parrot/clients/claude.py +1191 -0
  168. parrot/clients/factory.py +129 -0
  169. parrot/clients/google.py +4567 -0
  170. parrot/clients/gpt.py +1975 -0
  171. parrot/clients/grok.py +432 -0
  172. parrot/clients/groq.py +986 -0
  173. parrot/clients/hf.py +582 -0
  174. parrot/clients/models.py +18 -0
  175. parrot/conf.py +395 -0
  176. parrot/embeddings/__init__.py +9 -0
  177. parrot/embeddings/base.py +157 -0
  178. parrot/embeddings/google.py +98 -0
  179. parrot/embeddings/huggingface.py +74 -0
  180. parrot/embeddings/openai.py +84 -0
  181. parrot/embeddings/processor.py +88 -0
  182. parrot/exceptions.c +13868 -0
  183. parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
  184. parrot/exceptions.pxd +22 -0
  185. parrot/exceptions.pxi +15 -0
  186. parrot/exceptions.pyx +44 -0
  187. parrot/generators/__init__.py +29 -0
  188. parrot/generators/base.py +200 -0
  189. parrot/generators/html.py +293 -0
  190. parrot/generators/react.py +205 -0
  191. parrot/generators/streamlit.py +203 -0
  192. parrot/generators/template.py +105 -0
  193. parrot/handlers/__init__.py +4 -0
  194. parrot/handlers/agent.py +861 -0
  195. parrot/handlers/agents/__init__.py +1 -0
  196. parrot/handlers/agents/abstract.py +900 -0
  197. parrot/handlers/bots.py +338 -0
  198. parrot/handlers/chat.py +915 -0
  199. parrot/handlers/creation.sql +192 -0
  200. parrot/handlers/crew/ARCHITECTURE.md +362 -0
  201. parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
  202. parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
  203. parrot/handlers/crew/__init__.py +0 -0
  204. parrot/handlers/crew/handler.py +801 -0
  205. parrot/handlers/crew/models.py +229 -0
  206. parrot/handlers/crew/redis_persistence.py +523 -0
  207. parrot/handlers/jobs/__init__.py +10 -0
  208. parrot/handlers/jobs/job.py +384 -0
  209. parrot/handlers/jobs/mixin.py +627 -0
  210. parrot/handlers/jobs/models.py +115 -0
  211. parrot/handlers/jobs/worker.py +31 -0
  212. parrot/handlers/models.py +596 -0
  213. parrot/handlers/o365_auth.py +105 -0
  214. parrot/handlers/stream.py +337 -0
  215. parrot/interfaces/__init__.py +6 -0
  216. parrot/interfaces/aws.py +143 -0
  217. parrot/interfaces/credentials.py +113 -0
  218. parrot/interfaces/database.py +27 -0
  219. parrot/interfaces/google.py +1123 -0
  220. parrot/interfaces/hierarchy.py +1227 -0
  221. parrot/interfaces/http.py +651 -0
  222. parrot/interfaces/images/__init__.py +0 -0
  223. parrot/interfaces/images/plugins/__init__.py +24 -0
  224. parrot/interfaces/images/plugins/abstract.py +58 -0
  225. parrot/interfaces/images/plugins/analisys.py +148 -0
  226. parrot/interfaces/images/plugins/classify.py +150 -0
  227. parrot/interfaces/images/plugins/classifybase.py +182 -0
  228. parrot/interfaces/images/plugins/detect.py +150 -0
  229. parrot/interfaces/images/plugins/exif.py +1103 -0
  230. parrot/interfaces/images/plugins/hash.py +52 -0
  231. parrot/interfaces/images/plugins/vision.py +104 -0
  232. parrot/interfaces/images/plugins/yolo.py +66 -0
  233. parrot/interfaces/images/plugins/zerodetect.py +197 -0
  234. parrot/interfaces/o365.py +978 -0
  235. parrot/interfaces/onedrive.py +822 -0
  236. parrot/interfaces/sharepoint.py +1435 -0
  237. parrot/interfaces/soap.py +257 -0
  238. parrot/loaders/__init__.py +8 -0
  239. parrot/loaders/abstract.py +1131 -0
  240. parrot/loaders/audio.py +199 -0
  241. parrot/loaders/basepdf.py +53 -0
  242. parrot/loaders/basevideo.py +1568 -0
  243. parrot/loaders/csv.py +409 -0
  244. parrot/loaders/docx.py +116 -0
  245. parrot/loaders/epubloader.py +316 -0
  246. parrot/loaders/excel.py +199 -0
  247. parrot/loaders/factory.py +55 -0
  248. parrot/loaders/files/__init__.py +0 -0
  249. parrot/loaders/files/abstract.py +39 -0
  250. parrot/loaders/files/html.py +26 -0
  251. parrot/loaders/files/text.py +63 -0
  252. parrot/loaders/html.py +152 -0
  253. parrot/loaders/markdown.py +442 -0
  254. parrot/loaders/pdf.py +373 -0
  255. parrot/loaders/pdfmark.py +320 -0
  256. parrot/loaders/pdftables.py +506 -0
  257. parrot/loaders/ppt.py +476 -0
  258. parrot/loaders/qa.py +63 -0
  259. parrot/loaders/splitters/__init__.py +10 -0
  260. parrot/loaders/splitters/base.py +138 -0
  261. parrot/loaders/splitters/md.py +228 -0
  262. parrot/loaders/splitters/token.py +143 -0
  263. parrot/loaders/txt.py +26 -0
  264. parrot/loaders/video.py +89 -0
  265. parrot/loaders/videolocal.py +218 -0
  266. parrot/loaders/videounderstanding.py +377 -0
  267. parrot/loaders/vimeo.py +167 -0
  268. parrot/loaders/web.py +599 -0
  269. parrot/loaders/youtube.py +504 -0
  270. parrot/manager/__init__.py +5 -0
  271. parrot/manager/manager.py +1030 -0
  272. parrot/mcp/__init__.py +28 -0
  273. parrot/mcp/adapter.py +105 -0
  274. parrot/mcp/cli.py +174 -0
  275. parrot/mcp/client.py +119 -0
  276. parrot/mcp/config.py +75 -0
  277. parrot/mcp/integration.py +842 -0
  278. parrot/mcp/oauth.py +933 -0
  279. parrot/mcp/server.py +225 -0
  280. parrot/mcp/transports/__init__.py +3 -0
  281. parrot/mcp/transports/base.py +279 -0
  282. parrot/mcp/transports/grpc_session.py +163 -0
  283. parrot/mcp/transports/http.py +312 -0
  284. parrot/mcp/transports/mcp.proto +108 -0
  285. parrot/mcp/transports/quic.py +1082 -0
  286. parrot/mcp/transports/sse.py +330 -0
  287. parrot/mcp/transports/stdio.py +309 -0
  288. parrot/mcp/transports/unix.py +395 -0
  289. parrot/mcp/transports/websocket.py +547 -0
  290. parrot/memory/__init__.py +16 -0
  291. parrot/memory/abstract.py +209 -0
  292. parrot/memory/agent.py +32 -0
  293. parrot/memory/cache.py +175 -0
  294. parrot/memory/core.py +555 -0
  295. parrot/memory/file.py +153 -0
  296. parrot/memory/mem.py +131 -0
  297. parrot/memory/redis.py +613 -0
  298. parrot/models/__init__.py +46 -0
  299. parrot/models/basic.py +118 -0
  300. parrot/models/compliance.py +208 -0
  301. parrot/models/crew.py +395 -0
  302. parrot/models/detections.py +654 -0
  303. parrot/models/generation.py +85 -0
  304. parrot/models/google.py +223 -0
  305. parrot/models/groq.py +23 -0
  306. parrot/models/openai.py +30 -0
  307. parrot/models/outputs.py +285 -0
  308. parrot/models/responses.py +938 -0
  309. parrot/notifications/__init__.py +743 -0
  310. parrot/openapi/__init__.py +3 -0
  311. parrot/openapi/components.yaml +641 -0
  312. parrot/openapi/config.py +322 -0
  313. parrot/outputs/__init__.py +32 -0
  314. parrot/outputs/formats/__init__.py +108 -0
  315. parrot/outputs/formats/altair.py +359 -0
  316. parrot/outputs/formats/application.py +122 -0
  317. parrot/outputs/formats/base.py +351 -0
  318. parrot/outputs/formats/bokeh.py +356 -0
  319. parrot/outputs/formats/card.py +424 -0
  320. parrot/outputs/formats/chart.py +436 -0
  321. parrot/outputs/formats/d3.py +255 -0
  322. parrot/outputs/formats/echarts.py +310 -0
  323. parrot/outputs/formats/generators/__init__.py +0 -0
  324. parrot/outputs/formats/generators/abstract.py +61 -0
  325. parrot/outputs/formats/generators/panel.py +145 -0
  326. parrot/outputs/formats/generators/streamlit.py +86 -0
  327. parrot/outputs/formats/generators/terminal.py +63 -0
  328. parrot/outputs/formats/holoviews.py +310 -0
  329. parrot/outputs/formats/html.py +147 -0
  330. parrot/outputs/formats/jinja2.py +46 -0
  331. parrot/outputs/formats/json.py +87 -0
  332. parrot/outputs/formats/map.py +933 -0
  333. parrot/outputs/formats/markdown.py +172 -0
  334. parrot/outputs/formats/matplotlib.py +237 -0
  335. parrot/outputs/formats/mixins/__init__.py +0 -0
  336. parrot/outputs/formats/mixins/emaps.py +855 -0
  337. parrot/outputs/formats/plotly.py +341 -0
  338. parrot/outputs/formats/seaborn.py +310 -0
  339. parrot/outputs/formats/table.py +397 -0
  340. parrot/outputs/formats/template_report.py +138 -0
  341. parrot/outputs/formats/yaml.py +125 -0
  342. parrot/outputs/formatter.py +152 -0
  343. parrot/outputs/templates/__init__.py +95 -0
  344. parrot/pipelines/__init__.py +0 -0
  345. parrot/pipelines/abstract.py +210 -0
  346. parrot/pipelines/detector.py +124 -0
  347. parrot/pipelines/models.py +90 -0
  348. parrot/pipelines/planogram.py +3002 -0
  349. parrot/pipelines/table.sql +97 -0
  350. parrot/plugins/__init__.py +106 -0
  351. parrot/plugins/importer.py +80 -0
  352. parrot/py.typed +0 -0
  353. parrot/registry/__init__.py +18 -0
  354. parrot/registry/registry.py +594 -0
  355. parrot/scheduler/__init__.py +1189 -0
  356. parrot/scheduler/models.py +60 -0
  357. parrot/security/__init__.py +16 -0
  358. parrot/security/prompt_injection.py +268 -0
  359. parrot/security/security_events.sql +25 -0
  360. parrot/services/__init__.py +1 -0
  361. parrot/services/mcp/__init__.py +8 -0
  362. parrot/services/mcp/config.py +13 -0
  363. parrot/services/mcp/server.py +295 -0
  364. parrot/services/o365_remote_auth.py +235 -0
  365. parrot/stores/__init__.py +7 -0
  366. parrot/stores/abstract.py +352 -0
  367. parrot/stores/arango.py +1090 -0
  368. parrot/stores/bigquery.py +1377 -0
  369. parrot/stores/cache.py +106 -0
  370. parrot/stores/empty.py +10 -0
  371. parrot/stores/faiss_store.py +1157 -0
  372. parrot/stores/kb/__init__.py +9 -0
  373. parrot/stores/kb/abstract.py +68 -0
  374. parrot/stores/kb/cache.py +165 -0
  375. parrot/stores/kb/doc.py +325 -0
  376. parrot/stores/kb/hierarchy.py +346 -0
  377. parrot/stores/kb/local.py +457 -0
  378. parrot/stores/kb/prompt.py +28 -0
  379. parrot/stores/kb/redis.py +659 -0
  380. parrot/stores/kb/store.py +115 -0
  381. parrot/stores/kb/user.py +374 -0
  382. parrot/stores/models.py +59 -0
  383. parrot/stores/pgvector.py +3 -0
  384. parrot/stores/postgres.py +2853 -0
  385. parrot/stores/utils/__init__.py +0 -0
  386. parrot/stores/utils/chunking.py +197 -0
  387. parrot/telemetry/__init__.py +3 -0
  388. parrot/telemetry/mixin.py +111 -0
  389. parrot/template/__init__.py +3 -0
  390. parrot/template/engine.py +259 -0
  391. parrot/tools/__init__.py +23 -0
  392. parrot/tools/abstract.py +644 -0
  393. parrot/tools/agent.py +363 -0
  394. parrot/tools/arangodbsearch.py +537 -0
  395. parrot/tools/arxiv_tool.py +188 -0
  396. parrot/tools/calculator/__init__.py +3 -0
  397. parrot/tools/calculator/operations/__init__.py +38 -0
  398. parrot/tools/calculator/operations/calculus.py +80 -0
  399. parrot/tools/calculator/operations/statistics.py +76 -0
  400. parrot/tools/calculator/tool.py +150 -0
  401. parrot/tools/cloudwatch.py +988 -0
  402. parrot/tools/codeinterpreter/__init__.py +127 -0
  403. parrot/tools/codeinterpreter/executor.py +371 -0
  404. parrot/tools/codeinterpreter/internals.py +473 -0
  405. parrot/tools/codeinterpreter/models.py +643 -0
  406. parrot/tools/codeinterpreter/prompts.py +224 -0
  407. parrot/tools/codeinterpreter/tool.py +664 -0
  408. parrot/tools/company_info/__init__.py +6 -0
  409. parrot/tools/company_info/tool.py +1138 -0
  410. parrot/tools/correlationanalysis.py +437 -0
  411. parrot/tools/database/abstract.py +286 -0
  412. parrot/tools/database/bq.py +115 -0
  413. parrot/tools/database/cache.py +284 -0
  414. parrot/tools/database/models.py +95 -0
  415. parrot/tools/database/pg.py +343 -0
  416. parrot/tools/databasequery.py +1159 -0
  417. parrot/tools/db.py +1800 -0
  418. parrot/tools/ddgo.py +370 -0
  419. parrot/tools/decorators.py +271 -0
  420. parrot/tools/dftohtml.py +282 -0
  421. parrot/tools/document.py +549 -0
  422. parrot/tools/ecs.py +819 -0
  423. parrot/tools/edareport.py +368 -0
  424. parrot/tools/elasticsearch.py +1049 -0
  425. parrot/tools/employees.py +462 -0
  426. parrot/tools/epson/__init__.py +96 -0
  427. parrot/tools/excel.py +683 -0
  428. parrot/tools/file/__init__.py +13 -0
  429. parrot/tools/file/abstract.py +76 -0
  430. parrot/tools/file/gcs.py +378 -0
  431. parrot/tools/file/local.py +284 -0
  432. parrot/tools/file/s3.py +511 -0
  433. parrot/tools/file/tmp.py +309 -0
  434. parrot/tools/file/tool.py +501 -0
  435. parrot/tools/file_reader.py +129 -0
  436. parrot/tools/flowtask/__init__.py +19 -0
  437. parrot/tools/flowtask/tool.py +761 -0
  438. parrot/tools/gittoolkit.py +508 -0
  439. parrot/tools/google/__init__.py +18 -0
  440. parrot/tools/google/base.py +169 -0
  441. parrot/tools/google/tools.py +1251 -0
  442. parrot/tools/googlelocation.py +5 -0
  443. parrot/tools/googleroutes.py +5 -0
  444. parrot/tools/googlesearch.py +5 -0
  445. parrot/tools/googlesitesearch.py +5 -0
  446. parrot/tools/googlevoice.py +2 -0
  447. parrot/tools/gvoice.py +695 -0
  448. parrot/tools/ibisworld/README.md +225 -0
  449. parrot/tools/ibisworld/__init__.py +11 -0
  450. parrot/tools/ibisworld/tool.py +366 -0
  451. parrot/tools/jiratoolkit.py +1718 -0
  452. parrot/tools/manager.py +1098 -0
  453. parrot/tools/math.py +152 -0
  454. parrot/tools/metadata.py +476 -0
  455. parrot/tools/msteams.py +1621 -0
  456. parrot/tools/msword.py +635 -0
  457. parrot/tools/multidb.py +580 -0
  458. parrot/tools/multistoresearch.py +369 -0
  459. parrot/tools/networkninja.py +167 -0
  460. parrot/tools/nextstop/__init__.py +4 -0
  461. parrot/tools/nextstop/base.py +286 -0
  462. parrot/tools/nextstop/employee.py +733 -0
  463. parrot/tools/nextstop/store.py +462 -0
  464. parrot/tools/notification.py +435 -0
  465. parrot/tools/o365/__init__.py +42 -0
  466. parrot/tools/o365/base.py +295 -0
  467. parrot/tools/o365/bundle.py +522 -0
  468. parrot/tools/o365/events.py +554 -0
  469. parrot/tools/o365/mail.py +992 -0
  470. parrot/tools/o365/onedrive.py +497 -0
  471. parrot/tools/o365/sharepoint.py +641 -0
  472. parrot/tools/openapi_toolkit.py +904 -0
  473. parrot/tools/openweather.py +527 -0
  474. parrot/tools/pdfprint.py +1001 -0
  475. parrot/tools/powerbi.py +518 -0
  476. parrot/tools/powerpoint.py +1113 -0
  477. parrot/tools/pricestool.py +146 -0
  478. parrot/tools/products/__init__.py +246 -0
  479. parrot/tools/prophet_tool.py +171 -0
  480. parrot/tools/pythonpandas.py +630 -0
  481. parrot/tools/pythonrepl.py +910 -0
  482. parrot/tools/qsource.py +436 -0
  483. parrot/tools/querytoolkit.py +395 -0
  484. parrot/tools/quickeda.py +827 -0
  485. parrot/tools/resttool.py +553 -0
  486. parrot/tools/retail/__init__.py +0 -0
  487. parrot/tools/retail/bby.py +528 -0
  488. parrot/tools/sandboxtool.py +703 -0
  489. parrot/tools/sassie/__init__.py +352 -0
  490. parrot/tools/scraping/__init__.py +7 -0
  491. parrot/tools/scraping/docs/select.md +466 -0
  492. parrot/tools/scraping/documentation.md +1278 -0
  493. parrot/tools/scraping/driver.py +436 -0
  494. parrot/tools/scraping/models.py +576 -0
  495. parrot/tools/scraping/options.py +85 -0
  496. parrot/tools/scraping/orchestrator.py +517 -0
  497. parrot/tools/scraping/readme.md +740 -0
  498. parrot/tools/scraping/tool.py +3115 -0
  499. parrot/tools/seasonaldetection.py +642 -0
  500. parrot/tools/shell_tool/__init__.py +5 -0
  501. parrot/tools/shell_tool/actions.py +408 -0
  502. parrot/tools/shell_tool/engine.py +155 -0
  503. parrot/tools/shell_tool/models.py +322 -0
  504. parrot/tools/shell_tool/tool.py +442 -0
  505. parrot/tools/site_search.py +214 -0
  506. parrot/tools/textfile.py +418 -0
  507. parrot/tools/think.py +378 -0
  508. parrot/tools/toolkit.py +298 -0
  509. parrot/tools/webapp_tool.py +187 -0
  510. parrot/tools/whatif.py +1279 -0
  511. parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
  512. parrot/tools/workday/__init__.py +6 -0
  513. parrot/tools/workday/models.py +1389 -0
  514. parrot/tools/workday/tool.py +1293 -0
  515. parrot/tools/yfinance_tool.py +306 -0
  516. parrot/tools/zipcode.py +217 -0
  517. parrot/utils/__init__.py +2 -0
  518. parrot/utils/helpers.py +73 -0
  519. parrot/utils/parsers/__init__.py +5 -0
  520. parrot/utils/parsers/toml.c +12078 -0
  521. parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
  522. parrot/utils/parsers/toml.pyx +21 -0
  523. parrot/utils/toml.py +11 -0
  524. parrot/utils/types.cpp +20936 -0
  525. parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
  526. parrot/utils/types.pyx +213 -0
  527. parrot/utils/uv.py +11 -0
  528. parrot/version.py +10 -0
  529. parrot/yaml-rs/Cargo.lock +350 -0
  530. parrot/yaml-rs/Cargo.toml +19 -0
  531. parrot/yaml-rs/pyproject.toml +19 -0
  532. parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
  533. parrot/yaml-rs/src/lib.rs +222 -0
  534. requirements/docker-compose.yml +24 -0
  535. requirements/requirements-dev.txt +21 -0
parrot/bots/db/sql.py ADDED
@@ -0,0 +1,1255 @@
1
+ """
2
+ Enhanced SQL Database Agent Implementation for AI-Parrot.
3
+
4
+ Concrete implementation of AbstractDBAgent for SQL databases with support for:
5
+ - PostgreSQL, MySQL, and SQL Server
6
+ - Dictionary and string credentials
7
+ - Dual DSN generation for SQLAlchemy and asyncdb
8
+ - DatabaseQueryTool integration for query validation and execution
9
+ """
10
+
11
+ from typing import Dict, Any, List, Optional, Union
12
+ import re
13
+ from urllib.parse import urlparse, quote_plus
14
+ from datetime import datetime
15
+ import pandas as pd
16
+ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
17
+ from sqlalchemy.orm import sessionmaker
18
+ from sqlalchemy import text
19
+ from .abstract import AbstractDBAgent
20
+ from .tools import DatabaseSchema, TableMetadata
21
+ from ...models import AIMessage
22
+ from ...tools.databasequery import DatabaseQueryTool
23
+ from ...tools import ToolResult
24
+
25
+
26
+ class SQLAgent(AbstractDBAgent):
27
+ """
28
+ SQL Database Agent with dual DSN support and DatabaseQueryTool integration.
29
+
30
+ Supports PostgreSQL, MySQL, and SQL Server with both dictionary and string credentials.
31
+ """
32
+
33
+ # Database flavor mappings for SQLAlchemy
34
+ SQLALCHEMY_DIALECT_MAPPING = {
35
+ 'postgresql': 'postgresql+asyncpg',
36
+ 'pg': 'postgresql+asyncpg',
37
+ 'postgres': 'postgresql+asyncpg',
38
+ 'mysql': 'mysql+aiomysql',
39
+ 'sqlserver': 'mssql+aioodbc',
40
+ 'mssql': 'mssql+aioodbc'
41
+ }
42
+
43
+ # Default ports for databases
44
+ DEFAULT_PORTS = {
45
+ 'postgresql': 5432,
46
+ 'postgres': 5432,
47
+ 'mysql': 3306,
48
+ 'sqlserver': 1433,
49
+ 'mssql': 1433
50
+ }
51
+
52
def __init__(
    self,
    name: str = "SQLAgent",
    credentials: Union[str, Dict[str, Any]] = None,
    database_flavor: str = "postgresql",
    schema_name: str = "public",
    max_sample_rows: int = 2,
    **kwargs
):
    """
    Set up the agent state, validate the flavor and derive both DSNs.

    Args:
        name: Agent name
        credentials: Connection credentials (dict or connection string)
        database_flavor: Database type; lower-cased and checked against
            the keys of SQLALCHEMY_DIALECT_MAPPING
        schema_name: Target schema name
        max_sample_rows: Maximum rows to sample from each table
        **kwargs: Forwarded to the parent initializer; ``temperature``
            defaults to 0.0 when the caller does not supply one

    Raises:
        ValueError: If the flavor is not a key of
            SQLALCHEMY_DIALECT_MAPPING, or (raised by
            ``_process_credentials`` after the parent initializer has
            run) if ``credentials`` is neither a string nor a dict --
            which includes the default ``None``.
    """
    flavor = database_flavor.lower()
    self.database_flavor = flavor
    self.max_sample_rows = max_sample_rows
    self.async_session_maker = None

    # Two DSNs are kept: SQLAlchemy format for schema discovery and
    # asyncdb format for DatabaseQueryTool. Both are filled in below.
    self.discovery_dsn = None
    self.dsn = None
    self.credentials = None
    self.connection_dict = credentials if isinstance(credentials, dict) else None

    # Reject unknown flavors before the parent initializer runs
    if flavor not in self.SQLALCHEMY_DIALECT_MAPPING:
        raise ValueError(
            f"Unsupported database flavor: {database_flavor}"
        )

    # Low temperature by default to minimize hallucinations; an explicit
    # caller-provided value is kept as is
    kwargs.setdefault('temperature', 0.0)

    super().__init__(
        name=name,
        credentials=credentials,
        schema_name=schema_name,
        **kwargs
    )

    # Derive discovery_dsn / dsn from the credentials
    self._process_credentials(credentials)

    # Register SQL-specific tools
    self._setup_sql_tools()
104
+
105
+ def _dsn_for_sqlalchemy(self, connection_string: str) -> str:
106
+ """Adapt connection string for SQLAlchemy async drivers."""
107
+ parsed = urlparse(connection_string)
108
+
109
+ if parsed.scheme.startswith('postgresql') and '+asyncpg' not in parsed.scheme:
110
+ return connection_string.replace('postgresql://', 'postgresql+asyncpg://')
111
+ elif parsed.scheme.startswith('postgres') and '+asyncpg' not in parsed.scheme:
112
+ return connection_string.replace('postgres://', 'postgresql+asyncpg://')
113
+ elif parsed.scheme.startswith('mysql') and '+aiomysql' not in parsed.scheme:
114
+ return connection_string.replace('mysql://', 'mysql+aiomysql://')
115
+ elif parsed.scheme.startswith('mssql') and '+aioodbc' not in parsed.scheme:
116
+ return connection_string.replace('mssql://', 'mssql+aioodbc://')
117
+
118
+ return connection_string
119
+
120
+ def _dsn_for_asyncdb(self, connection_string: str) -> str:
121
+ """Adapt connection string for asyncdb format."""
122
+ parsed = urlparse(connection_string)
123
+
124
+ # Check if already in asyncdb format:
125
+ if parsed.scheme in ['postgres', 'mysql', 'mssql']:
126
+ return connection_string
127
+
128
+ # Convert SQLAlchemy formats to asyncdb formats
129
+ if parsed.scheme.startswith('postgresql'):
130
+ return connection_string.replace(
131
+ 'postgresql+asyncpg://', 'postgres://'
132
+ ).replace('postgresql://', 'postgres://')
133
+ elif parsed.scheme.startswith('mysql'):
134
+ return connection_string.replace(
135
+ 'mysql+aiomysql://', 'mysql://'
136
+ ).replace('mysql://', 'mysql://')
137
+ elif parsed.scheme.startswith('mssql'):
138
+ return connection_string.replace(
139
+ 'mssql+aioodbc://', 'mssql://'
140
+ ).replace('mssql://', 'mssql://')
141
+
142
+ return connection_string
143
+
144
+ def _process_credentials(self, credentials: Union[str, Dict[str, Any]]) -> None:
145
+ """
146
+ Process credentials and generate both discovery_dsn and dsn.
147
+
148
+ Args:
149
+ credentials: Either connection string or dictionary with connection params
150
+ """
151
+ if isinstance(credentials, str):
152
+ # Connection string provided
153
+ self.connection_string = credentials
154
+ self.discovery_dsn = self._dsn_for_sqlalchemy(credentials)
155
+ self.dsn = self._dsn_for_asyncdb(credentials)
156
+ self.credentials = {}
157
+ elif isinstance(credentials, dict):
158
+ # Dictionary credentials provided
159
+ self.connection_dict = credentials
160
+ self.discovery_dsn = self._build_sqlalchemy_dsn_from_dict(credentials)
161
+ self.dsn = self._build_asyncdb_dsn_from_dict(credentials)
162
+ self.connection_string = self.discovery_dsn
163
+ self.credentials = credentials
164
+ else:
165
+ raise ValueError(
166
+ "Credentials must be either a connection string or dictionary"
167
+ )
168
+
169
+ def _build_sqlalchemy_dsn_from_dict(self, creds: Dict[str, Any]) -> str:
170
+ """
171
+ Build SQLAlchemy DSN from credentials dictionary.
172
+
173
+ Args:
174
+ creds: Dictionary with keys like host, port, database, username, password
175
+
176
+ Returns:
177
+ SQLAlchemy-compatible connection string
178
+ """
179
+ # Extract credentials with defaults
180
+ host = creds.get('host', 'localhost')
181
+ port = creds.get('port', self.DEFAULT_PORTS.get(self.database_flavor, 5432))
182
+ database = creds.get('database', creds.get('dbname', 'postgres'))
183
+ username = creds.get('username', creds.get('user', 'postgres'))
184
+ password = creds.get('password', creds.get('pwd', ''))
185
+
186
+ # URL encode password to handle special characters
187
+ encoded_password = quote_plus(str(password)) if password else ''
188
+
189
+ # Get SQLAlchemy dialect
190
+ dialect = self.SQLALCHEMY_DIALECT_MAPPING[self.database_flavor]
191
+
192
+ # Build connection string
193
+ if encoded_password:
194
+ dsn = f"{dialect}://{username}:{encoded_password}@{host}:{port}/{database}"
195
+ else:
196
+ dsn = f"{dialect}://{username}@{host}:{port}/{database}"
197
+
198
+ # Add any additional parameters
199
+ params = []
200
+ for key, value in creds.items():
201
+ if key not in ['host', 'port', 'database', 'dbname', 'username', 'user', 'password', 'pwd']:
202
+ params.append(f"{key}={value}")
203
+
204
+ if params:
205
+ dsn += "?" + "&".join(params)
206
+
207
+ return dsn
208
+
209
+ def _build_asyncdb_dsn_from_dict(self, creds: Dict[str, Any]) -> str:
210
+ """
211
+ Build asyncdb DSN from credentials dictionary.
212
+
213
+ Args:
214
+ creds: Dictionary with connection parameters
215
+
216
+ Returns:
217
+ asyncdb-compatible connection string (postgres://...)
218
+ """
219
+ # Extract credentials
220
+ host = creds.get('host', 'localhost')
221
+ port = creds.get('port', self.DEFAULT_PORTS.get(self.database_flavor, 5432))
222
+ database = creds.get('database', creds.get('dbname', 'postgres'))
223
+ username = creds.get('username', creds.get('user', 'postgres'))
224
+ password = creds.get('password', creds.get('pwd', ''))
225
+
226
+ # URL encode password
227
+ encoded_password = quote_plus(str(password)) if password else ''
228
+
229
+ # Get asyncdb scheme (postgres for PostgreSQL regardless of flavor name)
230
+ if self.database_flavor in ['postgresql', 'postgres']:
231
+ scheme = 'postgres'
232
+ elif self.database_flavor == 'mysql':
233
+ scheme = 'mysql'
234
+ elif self.database_flavor in ['sqlserver', 'mssql']:
235
+ scheme = 'mssql'
236
+ else:
237
+ scheme = 'postgres' # Default fallback
238
+
239
+ # Build DSN
240
+ if encoded_password:
241
+ dsn = f"{scheme}://{username}:{encoded_password}@{host}:{port}/{database}"
242
+ else:
243
+ dsn = f"{scheme}://{username}@{host}:{port}/{database}"
244
+
245
+ return dsn
246
+
247
+ def _setup_sql_tools(self):
248
+ """Setup SQL-specific tools including DatabaseQueryTool."""
249
+ # The DatabaseQueryTool should already be registered in the parent class
250
+ # We just need to ensure it's configured properly
251
+ pass
252
+
253
async def connect_database(self) -> None:
    """
    Connect to the SQL database using a SQLAlchemy async engine.

    The engine is built from ``self.discovery_dsn``, verified with a
    ``SELECT 1`` round-trip and only then published as ``self.engine`` /
    ``self.async_session_maker``. Afterwards the DatabaseQueryTool
    connection is probed as well.

    Raises:
        ValueError: If ``self.discovery_dsn`` is empty.
        Exception: Any error raised while creating or testing the engine is
            logged and re-raised.
    """
    if not self.discovery_dsn:
        raise ValueError("Discovery DSN is required")

    engine = None
    try:
        # Create async engine for schema discovery
        engine = create_async_engine(
            self.discovery_dsn,
            echo=False,
            pool_pre_ping=True,
            pool_recycle=3600,
        )

        session_maker = sessionmaker(
            engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )

        # Test connection before exposing the engine to the rest of the agent
        async with engine.begin() as conn:
            await conn.execute(text("SELECT 1"))

        self.engine = engine
        self.async_session_maker = session_maker

        self.logger.info(
            f"Successfully connected to {self.database_flavor} database using SQLAlchemy"
        )

        # Test DatabaseQueryTool connection
        await self._test_database_query_tool()

    except Exception as e:
        self.logger.error(f"Failed to connect to database: {e}")
        # An engine that never passed the connectivity test must not linger:
        # release its pool instead of leaving a broken engine on the instance.
        if engine is not None and getattr(self, "engine", None) is not engine:
            await engine.dispose()
        raise
288
+
289
+ async def _test_database_query_tool(self) -> None:
290
+ """Test DatabaseQueryTool connection."""
291
+ try:
292
+ # Get database query tool from registered tools
293
+ db_tool = self.tool_manager.get_tool('database_query')
294
+ if db_tool:
295
+ # Test with a simple query
296
+ test_result = await db_tool.execute(
297
+ driver='pg' if self.database_flavor in ['postgresql', 'postgres', 'pg'] else self.database_flavor,
298
+ query="SELECT 1 as test_column LIMIT 1",
299
+ dsn=self.dsn,
300
+ credentials=self.credentials or None,
301
+ output_format='native'
302
+ )
303
+
304
+ if test_result.status == "success":
305
+ self.logger.debug(
306
+ "DatabaseQueryTool connection test successful"
307
+ )
308
+ else:
309
+ self.logger.warning(
310
+ f"DatabaseQueryTool test failed: {test_result.error}"
311
+ )
312
+ else:
313
+ self.logger.warning(
314
+ "DatabaseQueryTool not found in registered tools"
315
+ )
316
+
317
+ except Exception as e:
318
+ self.logger.warning(
319
+ f"DatabaseQueryTool test failed: {e}"
320
+ )
321
+
322
async def extract_schema_metadata(self) -> DatabaseSchema:
    """
    Extract complete schema metadata from the SQL database.

    Connects first when no engine exists yet, then collects the database
    name and per-table metadata inside one engine transaction. Views,
    functions and procedures are currently returned as empty lists.

    Returns:
        A ``DatabaseSchema`` describing the configured schema.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    if not self.engine:
        await self.connect_database()

    try:
        async with self.engine.begin() as connection:
            # Resolve the name of the database we are connected to
            name_sql = await self._get_database_name_query()
            name_result = await connection.execute(text(name_sql))
            current_db = name_result.scalar()

            table_list = await self._extract_tables_metadata(connection)

            # View extraction is simplified for now
            view_list = []

            # NOTE(review): both DSNs are stored verbatim and may embed the
            # password - confirm that is acceptable for downstream consumers.
            extraction_info = {
                "schema_name": self.schema_name,
                "extraction_timestamp": datetime.now().isoformat(),
                "total_tables": len(table_list),
                "total_views": len(view_list),
                "discovery_dsn": self.discovery_dsn,
                "asyncdb_dsn": self.dsn
            }

            schema_metadata = DatabaseSchema(
                database_name=current_db or "unknown",
                database_type=self.database_flavor,
                tables=table_list,
                views=view_list,
                functions=[],
                procedures=[],
                metadata=extraction_info
            )

            self.logger.info(
                f"Extracted metadata for {len(table_list)} tables"
            )
            return schema_metadata

    except Exception as e:
        self.logger.error(f"Failed to extract schema metadata: {e}")
        raise
366
+
367
+ async def _get_database_name_query(self) -> str:
368
+ """Get database name query based on database flavor."""
369
+ if self.database_flavor in ['postgresql', 'postgres']:
370
+ return "SELECT current_database()"
371
+ elif self.database_flavor == 'mysql':
372
+ return "SELECT database()"
373
+ elif self.database_flavor in ['sqlserver', 'mssql']:
374
+ return "SELECT DB_NAME()"
375
+ else:
376
+ return "SELECT 'unknown' as database_name"
377
+
378
async def _extract_tables_metadata(self, conn) -> List[TableMetadata]:
    """
    Extract metadata for all base tables in ``self.schema_name``.

    Args:
        conn: Open async connection used for the information_schema lookups.

    Returns:
        One ``TableMetadata`` per base table, ordered by table name.
    """
    # The listing statement is the same for PostgreSQL, MySQL and SQL
    # Server, so a single query covers every flavor.
    table_query = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = :schema_name
        AND table_type = 'BASE TABLE'
        ORDER BY table_name
    """

    listing = await conn.execute(
        text(table_query), {"schema_name": self.schema_name}
    )

    # Tables are processed one after another on the shared connection.
    return [
        await self._extract_single_table_metadata(conn, record[0])
        for record in listing.fetchall()
    ]
420
+
421
async def _extract_single_table_metadata(self, conn, table_name: str) -> TableMetadata:
    """
    Extract detailed metadata for a single table.

    Args:
        conn: Open async connection used for the information_schema lookups.
        table_name: Name of the table inside ``self.schema_name``.

    Returns:
        ``TableMetadata`` with columns, primary keys, foreign keys and sample
        rows. Indexes and description are not extracted yet.
    """
    column_info = await self._get_table_columns(conn, table_name)
    pk_columns = await self._get_primary_keys(conn, table_name)
    fk_info = await self._get_foreign_keys(conn, table_name)
    # Sample rows go through the DatabaseQueryTool, not through ``conn``.
    sample_rows = await self._get_sample_data_via_tool(table_name)

    table_details = dict(
        name=table_name,
        schema=self.schema_name,
        columns=column_info,
        primary_keys=pk_columns,
        foreign_keys=fk_info,
        indexes=[],        # Simplified for now
        description=None,  # Simplified for now
        sample_data=sample_rows,
    )
    return TableMetadata(**table_details)
445
+
446
async def _get_table_columns(self, conn, table_name: str) -> List[Dict[str, Any]]:
    """
    Get column information for a table.

    Args:
        conn: Open async connection used for the information_schema lookup.
        table_name: Name of the table inside ``self.schema_name``.

    Returns:
        One dict per column, in ordinal order, with the keys ``name``,
        ``type``, ``nullable``, ``default``, ``max_length``, ``precision``
        and ``scale``.
    """
    # The statement is the same for PostgreSQL, MySQL and SQL Server.
    column_query = """
        SELECT
            column_name,
            data_type,
            is_nullable,
            column_default,
            character_maximum_length,
            numeric_precision,
            numeric_scale
        FROM information_schema.columns
        WHERE table_schema = :schema_name
        AND table_name = :table_name
        ORDER BY ordinal_position
    """
    bind_values = {
        "schema_name": self.schema_name,
        "table_name": table_name
    }
    result = await conn.execute(text(column_query), bind_values)

    return [
        {
            "name": record[0],
            "type": record[1],
            "nullable": record[2] == "YES",
            "default": record[3],
            "max_length": record[4],
            "precision": record[5],
            "scale": record[6]
        }
        for record in result.fetchall()
    ]
512
+
513
async def _get_primary_keys(self, conn, table_name: str) -> List[str]:
    """
    Get primary key columns for a table.

    Args:
        conn: Open async connection used for the information_schema lookup.
        table_name: Name of the table inside ``self.schema_name``.

    Returns:
        Primary-key column names in key order (empty when the table has no
        primary key).
    """
    if self.database_flavor == 'mysql':
        # MySQL always names the primary-key constraint 'PRIMARY'.
        query = """
            SELECT column_name
            FROM information_schema.key_column_usage
            WHERE table_schema = :schema_name
            AND table_name = :table_name
            AND constraint_name = 'PRIMARY'
            ORDER BY ordinal_position
        """
    else:
        # PostgreSQL and SQL Server use arbitrary constraint names (for
        # example 'orders_pkey' or 'PK_orders'), so resolve the constraint
        # through table_constraints instead of assuming a fixed name.
        query = """
            SELECT column_name
            FROM information_schema.key_column_usage
            WHERE table_schema = :schema_name
            AND table_name = :table_name
            AND constraint_name IN (
                SELECT constraint_name
                FROM information_schema.table_constraints
                WHERE table_schema = :schema_name
                AND table_name = :table_name
                AND constraint_type = 'PRIMARY KEY'
            )
            ORDER BY ordinal_position
        """

    result = await conn.execute(text(query), {
        "schema_name": self.schema_name,
        "table_name": table_name
    })

    return [row[0] for row in result.fetchall()]
546
+
547
async def _get_foreign_keys(self, conn, table_name: str) -> List[Dict[str, Any]]:
    """
    Get foreign key information for a table.

    Args:
        conn: Open async connection used for the information_schema lookup.
        table_name: Name of the table inside ``self.schema_name``.

    Returns:
        One dict per foreign-key column with the keys ``column``,
        ``referenced_table_schema``, ``referenced_table`` and
        ``referenced_column``.
    """
    if self.database_flavor == 'mysql':
        # MySQL has no information_schema.constraint_column_usage view; the
        # referenced side is exposed directly on key_column_usage.
        query = """
            SELECT
                column_name,
                referenced_table_schema,
                referenced_table_name,
                referenced_column_name
            FROM information_schema.key_column_usage
            WHERE table_schema = :schema_name
            AND table_name = :table_name
            AND referenced_table_name IS NOT NULL
        """
    else:
        # Join on schema as well as name: constraint names are only unique
        # per schema, so a name-only join can pair unrelated constraints.
        # NOTE(review): composite foreign keys still yield the cross product
        # of local and referenced columns with this join.
        query = """
            SELECT
                kcu.column_name,
                ccu.table_schema AS referenced_table_schema,
                ccu.table_name AS referenced_table_name,
                ccu.column_name AS referenced_column_name
            FROM information_schema.key_column_usage kcu
            JOIN information_schema.constraint_column_usage ccu
            ON kcu.constraint_name = ccu.constraint_name
            AND kcu.constraint_schema = ccu.constraint_schema
            WHERE kcu.table_schema = :schema_name
            AND kcu.table_name = :table_name
            AND kcu.constraint_name IN (
                SELECT constraint_name
                FROM information_schema.table_constraints
                WHERE table_schema = :schema_name
                AND table_name = :table_name
                AND constraint_type = 'FOREIGN KEY'
            )
        """

    result = await conn.execute(text(query), {
        "schema_name": self.schema_name,
        "table_name": table_name
    })

    return [
        {
            "column": row[0],
            "referenced_table_schema": row[1],
            "referenced_table": row[2],
            "referenced_column": row[3]
        }
        for row in result.fetchall()
    ]
584
+
585
+ async def _get_sample_data_via_tool(self, table_name: str) -> List[Dict[str, Any]]:
586
+ """Get sample data using DatabaseQueryTool."""
587
+ try:
588
+ # Get database query tool
589
+ db_tool = self.tool_manager.get_tool('database_query')
590
+ if not db_tool:
591
+ self.logger.warning("DatabaseQueryTool not found")
592
+ return []
593
+
594
+ # Build sample query
595
+ full_table_name = f'"{self.schema_name}"."{table_name}"' if self.schema_name != 'public' else f'"{table_name}"'
596
+ sample_query = f"SELECT * FROM {full_table_name} LIMIT {self.max_sample_rows}"
597
+
598
+ # Execute query
599
+ result = await db_tool.execute(
600
+ driver='pg' if self.database_flavor in ['postgresql', 'postgres'] else self.database_flavor,
601
+ query=sample_query,
602
+ dsn=self.dsn,
603
+ credentials=self.connection_dict,
604
+ output_format='json'
605
+ )
606
+ if result.status == "success":
607
+ return result.result
608
+ else:
609
+ self.logger.warning(f"Could not get sample data for {table_name}: {result.error}")
610
+ return []
611
+
612
+ except Exception as e:
613
+ self.logger.warning(f"Error getting sample data for {table_name}: {e}")
614
+ return []
615
+
616
async def generate_query(
    self,
    natural_language_query: str,
    target_tables: Optional[List[str]] = None,
    query_type: str = "SELECT"
) -> Dict[str, Any]:
    """
    Generate a SQL query from natural language and validate it.

    Args:
        natural_language_query: The request to translate into SQL.
        target_tables: Optional table names used to narrow the schema context.
        query_type: Kind of statement to ask the LLM for.

    Returns:
        Dict with the keys ``query``, ``query_type``, ``tables_used``,
        ``schema_context_used`` (number of context entries), ``validation``
        and ``natural_language_input``.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        context_entries = await self._get_schema_context_for_query(
            natural_language_query, target_tables
        )

        llm_prompt = self._build_query_generation_prompt(
            natural_language_query=natural_language_query,
            schema_context=context_entries,
            query_type=query_type,
            database_flavor=self.database_flavor
        )

        # Zero temperature for deterministic output; tools are explicitly
        # disabled to prevent recursion.
        llm_options = {
            "prompt": llm_prompt,
            "model": self._llm_model,
            "temperature": 0.0,
            "use_tools": False,
            "tools": [],
        }
        llm_response = await self._llm.ask(**llm_options)

        sql_text = self._extract_sql_from_response(str(llm_response.output))

        # Validate the generated statement through the DatabaseQueryTool
        validation_outcome = await self._validate_query_with_tool(sql_text)

        return {
            "query": sql_text,
            "query_type": query_type,
            "tables_used": self._extract_tables_from_query(sql_text),
            "schema_context_used": len(context_entries),
            "validation": validation_outcome,
            "natural_language_input": natural_language_query
        }

    except Exception as e:
        self.logger.error(f"Failed to generate query: {e}")
        raise
666
+
667
+ async def _validate_query_with_tool(self, query: str) -> Dict[str, Any]:
668
+ """Validate query using DatabaseQueryTool with LIMIT 0."""
669
+ try:
670
+ # Get database query tool
671
+ db_tool = None
672
+ for tool in self.tools:
673
+ if isinstance(tool, DatabaseQueryTool):
674
+ db_tool = tool
675
+ break
676
+
677
+ if not db_tool:
678
+ return {
679
+ "valid": False,
680
+ "error": "DatabaseQueryTool not available",
681
+ "method": "tool_validation"
682
+ }
683
+
684
+ # Modify query to add LIMIT 0 for validation (no data returned)
685
+ if query.strip().upper().startswith('SELECT'):
686
+ validation_query = f"SELECT * FROM ({query.rstrip(';')}) AS validation_subquery LIMIT 0"
687
+ else:
688
+ # For non-SELECT queries, we can't easily validate without risk
689
+ validation_query = query
690
+
691
+ # Execute validation query
692
+ result = await db_tool.execute(
693
+ driver='pg' if self.database_flavor in ['postgresql', 'postgres'] else self.database_flavor,
694
+ query=validation_query,
695
+ dsn=self.dsn,
696
+ credentials=self.connection_dict,
697
+ output_format='native'
698
+ )
699
+
700
+ return {
701
+ "valid": result.status == "success",
702
+ "error": result.error if result.status == "error" else None,
703
+ "method": "database_query_tool",
704
+ "validation_query": validation_query
705
+ }
706
+
707
+ except Exception as e:
708
+ return {
709
+ "valid": False,
710
+ "error": str(e),
711
+ "method": "tool_validation"
712
+ }
713
+
714
async def explain_query(self, query: str) -> str:
    """
    Explain a database query (e.g. EXPLAIN ANALYZE).

    Args:
        query: The SQL query to explain

    Returns:
        The execution plan as a string, or a message describing the failure
        (errors are reported in the returned text, not raised).
    """
    postgres_flavors = ('postgresql', 'postgres', 'pg')
    try:
        uses_postgres = self.database_flavor in postgres_flavors
        if uses_postgres:
            # JSON output is easier to post-process; ANALYZE adds actual
            # execution statistics.
            statement = f"EXPLAIN (FORMAT JSON, ANALYZE) {query}"
        elif self.database_flavor == 'mysql':
            statement = f"EXPLAIN ANALYZE {query}"
        else:
            statement = f"EXPLAIN {query}"

        # execute_query only appends a LIMIT to statements starting with
        # SELECT, so the limit value has no effect on an EXPLAIN.
        outcome = await self.execute_query(statement, limit=0)

        if not outcome["success"]:
            return f"Failed to explain query: {outcome['error']}"

        frame = outcome["data"]
        if self.database_flavor not in postgres_flavors:
            return frame.to_string()

        # The PostgreSQL JSON plan usually arrives as one cell holding a
        # list/dict; fall back to the tabular rendering if that shape is off.
        try:
            plan_cell = frame.iloc[0, 0]
            if isinstance(plan_cell, (list, dict)):
                return json.dumps(plan_cell, indent=2)
            return str(plan_cell)
        except Exception:
            return frame.to_string()

    except Exception as e:
        self.logger.error(f"Error explaining query: {e}")
        return f"Error explaining query: {str(e)}"
759
+
760
async def execute_query(self, query: str, limit: int = 200) -> Dict[str, Any]:
    """
    Execute a SQL query through the DatabaseQueryTool and return the result.

    A ``LIMIT`` clause is appended to SELECT statements that do not already
    contain one, except on SQL Server where ``LIMIT`` is not valid syntax.

    Args:
        query: SQL statement to run.
        limit: Row cap appended to unbounded SELECT statements.

    Returns:
        Dict with ``success`` and ``query``. On success also ``data`` (the
        tool result requested in pandas format), ``columns``, ``row_count``,
        ``tool_used`` and ``raw_result``; on failure ``error``. Errors are
        reported in the dict, not raised.
    """
    try:
        db_tool = self.tool_manager.get_tool('database_query')
        if not db_tool:
            return {
                "success": False,
                "error": "DatabaseQueryTool not available",
                "query": query
            }

        execution_query = query
        is_select = query.strip().upper().startswith('SELECT')
        # Match LIMIT as a whole word so identifiers such as "limit_price"
        # do not suppress the row cap.
        has_limit = re.search(r'\bLIMIT\b', query, re.IGNORECASE) is not None
        supports_limit = self.database_flavor not in ('sqlserver', 'mssql')
        if is_select and not has_limit and supports_limit:
            # Strip whitespace around the trailing semicolon too; otherwise
            # "SELECT 1; " would become "SELECT 1;  LIMIT n".
            bare_statement = query.strip().rstrip(';').rstrip()
            execution_query = f"{bare_statement} LIMIT {limit}"

        # Execute query (returns a ToolResult)
        result = await db_tool.execute(
            driver='pg' if self.database_flavor in ('postgresql', 'postgres') else self.database_flavor,
            query=execution_query,
            dsn=self.dsn,
            credentials=self.connection_dict,
            output_format='pandas'
        )

        if result.status == "success":
            data = result.result
            columns = data.columns.tolist() if not data.empty else []
            row_count = len(data) if not data.empty else 0
            return {
                "success": True,
                "data": data,
                "columns": columns,
                "row_count": row_count,
                "query": execution_query,
                "tool_used": "DatabaseQueryTool",
                "raw_result": result
            }

        return {
            "success": False,
            "error": result.error,
            "query": execution_query,
            "tool_used": "DatabaseQueryTool",
            "raw_result": result
        }

    except Exception as e:
        self.logger.error(f"Query execution failed: {e}")
        return {
            "success": False,
            "error": str(e),
            "query": query,
            "tool_used": "DatabaseQueryTool",
            "raw_result": None
        }
821
+
822
+ async def _get_schema_context_for_query(
823
+ self,
824
+ natural_language_query: str,
825
+ target_tables: Optional[List[str]] = None
826
+ ) -> List[Dict[str, Any]]:
827
+ """Get relevant schema context for query generation."""
828
+ if target_tables:
829
+ context = []
830
+ for table_name in target_tables:
831
+ table_info = await self.search_schema(
832
+ search_term=table_name,
833
+ search_type="tables",
834
+ limit=1
835
+ )
836
+ if table_info:
837
+ context.extend(table_info)
838
+ return context
839
+ else:
840
+ return await self.search_schema(
841
+ search_term=natural_language_query,
842
+ search_type="all",
843
+ limit=5
844
+ )
845
+
846
+ def _build_query_generation_prompt(
847
+ self,
848
+ natural_language_query: str,
849
+ schema_context: List[Dict[str, Any]],
850
+ query_type: str,
851
+ database_flavor: str
852
+ ) -> str:
853
+ """Build prompt for LLM query generation."""
854
+ prompt = f"""
855
+ You are an expert SQL developer working with a {database_flavor} database.
856
+ Generate a clean, efficient {query_type} SQL query based on the natural language request and schema information.
857
+
858
+ Natural Language Request: {natural_language_query}
859
+
860
+ Available Schema Information:
861
+ """
862
+
863
+ for i, context in enumerate(schema_context[:3], 1):
864
+ prompt += f"\n{i}. {context.get('content', '')}\n"
865
+
866
+ prompt += f"""
867
+ Requirements:
868
+ 1. Generate valid {database_flavor} SQL with clean formatting
869
+ 2. Use appropriate {database_flavor} syntax and functions
870
+ 3. Use simple column names unless JOINs require qualification
871
+ 4. Use table aliases for readability in JOINs
872
+ 5. Only use double quotes for identifiers with special characters
873
+ 6. Include appropriate WHERE clauses and filters
874
+ 7. Optimize for performance and readability
875
+ 8. Return ONLY the SQL query without explanations or formatting
876
+
877
+ Query Type: {query_type}
878
+ Database: {database_flavor}
879
+
880
+ SQL Query:"""
881
+
882
+ return prompt
883
+
884
+ def _extract_sql_from_response(self, response_text: str) -> str:
885
+ """Extract SQL query from LLM response."""
886
+ # Remove markdown code blocks if present
887
+ if "```sql" in response_text:
888
+ lines = response_text.split('\n')
889
+ sql_lines = []
890
+ in_sql_block = False
891
+
892
+ for line in lines:
893
+ if line.strip().startswith("```sql"):
894
+ in_sql_block = True
895
+ continue
896
+ elif line.strip() == "```" and in_sql_block:
897
+ break
898
+ elif in_sql_block:
899
+ sql_lines.append(line)
900
+
901
+ return '\n'.join(sql_lines).strip()
902
+ else:
903
+ return response_text.strip()
904
+
905
async def ask(
    self,
    question: str = None,
    user_context: str = "",
    context: str = "",
    return_results: bool = True,  # Controls automatic query execution
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    use_conversation_history: bool = True,
    **kwargs
) -> AIMessage:
    """
    Enhanced ask method that can automatically execute generated SQL queries.

    Args:
        question: The user's question about the database
        user_context: User-specific context for database interaction
        context: Additional context about data location, schema guidance
        return_results: If True, automatically execute the first SQL query
            found in the response and attach its data
        session_id: Session identifier for conversation history
        user_id: User identifier
        use_conversation_history: Whether to use conversation history
        **kwargs: Additional arguments for LLM (``prompt`` is accepted as a
            legacy alias for ``question``)

    Returns:
        AIMessage: The parent's response. When a query was executed,
        ``response`` holds the original text, ``raw_response`` the
        ``execute_query`` result dict and, on success, ``output`` the
        returned data.
    """
    # Backwards compatibility
    if question is None:
        question = kwargs.get('prompt')

    # First, get the standard response from the parent method
    response = await super().ask(
        question=question,
        user_context=user_context,
        context=context,
        session_id=session_id,
        user_id=user_id,
        use_conversation_history=use_conversation_history,
        **kwargs
    )

    if not return_results:
        return response

    # Try to extract and execute SQL queries from the response
    try:
        response_text = str(response.output) if response.output else ""
        sql_queries = self._extract_queries(response_text)

        if sql_queries:
            # Execute the first/main SQL query
            main_query = sql_queries[0]
            self.logger.debug(
                f"Auto-executing extracted query: {main_query[:100]}..."
            )

            result = await self.execute_query(
                query=main_query
            )
            # execute_query reports the outcome under the 'success' key.
            executed_ok = bool(result.get('success'))

            # Preserve original response text
            response.response = response_text
            if executed_ok:
                # Replace the output with the returned data only when the
                # query ran; a failed run keeps the explanation instead of
                # overwriting it with None.
                response.output = result.get('data', None)
            response.raw_response = result  # execute_query result dict

            # Add execution metadata if response has a non-empty metadata dict
            if hasattr(response, 'metadata') and response.metadata:
                response.metadata.update({
                    'auto_executed_query': True,
                    'executed_query': main_query,
                    'execution_success': executed_ok,
                    'row_count': result.get('row_count', 0),
                    'columns': result.get('columns', []),
                    'error': result.get('error', None)
                })

    except Exception as e:
        # Don't fail the entire request: the user still gets the
        # explanation even if execution fails.
        self.logger.warning(
            f"Failed to auto-execute query: {e}"
        )

    return response
994
+
995
async def search_schema(
    self,
    search_term: str,
    search_type: str = "all",
    limit: int = 10
) -> List[Dict[str, Any]]:
    """
    Search the database schema using SQL queries against information_schema.

    Args:
        search_term: Term to search for (wrapped in ``%...%`` for LIKE matching)
        search_type: Type of search ('tables', 'columns', 'all')
        limit: Maximum number of results

    Returns:
        List of matching schema objects. Table hits have the keys ``type``,
        ``name``, ``schema`` and ``description``; column hits additionally
        have ``table`` and ``metadata``. An empty list is returned when no
        engine is connected or when the search raises (the error is logged).
    """
    results = []

    # Check cache first
    if self.cache:
        cached_results = await self.cache.get(search_term, search_type, limit)
        if cached_results is not None:
            self.logger.info(f"Schema search cache hit for term: {search_term}")
            return cached_results

    term_pattern = f"%{search_term}%"
    flavor = self.database_flavor
    is_postgres = flavor in ('postgresql', 'postgres', 'pg')
    is_mysql = flavor == 'mysql'
    is_mssql = flavor in ('sqlserver', 'mssql')

    try:
        search_tables = search_type in ("all", "tables")
        search_columns = search_type in ("all", "columns")

        # --- Search Tables ---
        if search_tables:
            if is_postgres:
                # Support schema.table search
                query = """
                    SELECT table_schema, table_name, 'TABLE' as type
                    FROM information_schema.tables
                    WHERE (table_name ILIKE :term
                    OR table_schema || '.' || table_name ILIKE :term)
                    AND table_schema NOT IN ('information_schema', 'pg_catalog')
                    AND table_type = 'BASE TABLE'
                    LIMIT :limit
                """
            elif is_mysql:
                query = """
                    SELECT table_schema, table_name, 'TABLE' as type
                    FROM information_schema.tables
                    WHERE (table_name LIKE :term
                    OR CONCAT(table_schema, '.', table_name) LIKE :term)
                    AND table_schema = DATABASE()
                    AND table_type = 'BASE TABLE'
                    LIMIT :limit
                """
            elif is_mssql:
                # SQL Server has no LIMIT clause; TOP takes a parenthesised
                # expression, which allows a bound parameter.
                query = """
                    SELECT TOP (:limit) table_schema, table_name, 'TABLE' as type
                    FROM information_schema.tables
                    WHERE table_name LIKE :term
                """
            else:  # Generic
                query = """
                    SELECT table_schema, table_name, 'TABLE' as type
                    FROM information_schema.tables
                    WHERE table_name LIKE :term
                    LIMIT :limit
                """

            if self.engine:
                async with self.engine.connect() as conn:
                    result_proxy = await conn.execute(
                        text(query), {"term": term_pattern, "limit": limit}
                    )
                    for row in result_proxy.fetchall():
                        results.append({
                            "type": "table",
                            "name": row[1],
                            "schema": row[0],
                            "description": f"Table: {row[0]}.{row[1]}"
                        })

        # --- Search Columns ---
        if search_columns and len(results) < limit:
            # Only fill the slots the table search left open.
            current_limit = limit - len(results)
            if is_postgres:
                query = """
                    SELECT table_schema, table_name, column_name, data_type
                    FROM information_schema.columns
                    WHERE column_name ILIKE :term
                    AND table_schema NOT IN ('information_schema', 'pg_catalog')
                    LIMIT :limit
                """
            elif is_mysql:
                query = """
                    SELECT table_schema, table_name, column_name, data_type
                    FROM information_schema.columns
                    WHERE column_name LIKE :term
                    AND table_schema = DATABASE()
                    LIMIT :limit
                """
            elif is_mssql:
                query = """
                    SELECT TOP (:limit) table_schema, table_name, column_name, data_type
                    FROM information_schema.columns
                    WHERE column_name LIKE :term
                """
            else:  # Generic
                query = """
                    SELECT table_schema, table_name, column_name, data_type
                    FROM information_schema.columns
                    WHERE column_name LIKE :term
                    LIMIT :limit
                """

            if self.engine:
                async with self.engine.connect() as conn:
                    result_proxy = await conn.execute(
                        text(query), {"term": term_pattern, "limit": current_limit}
                    )
                    for row in result_proxy.fetchall():
                        results.append({
                            "type": "column",
                            "table": row[1],
                            "schema": row[0],
                            "name": row[2],
                            "description": f"Column: {row[2]} (Type: {row[3]}) in {row[0]}.{row[1]}",
                            "metadata": f"Type: {row[3]}"
                        })

        # Cache only non-empty results: an empty result may stem from a
        # transient issue or a bad query and must not be pinned as a miss.
        if self.cache and results:
            await self.cache.set(search_term, search_type, limit, results)

        return results

    except Exception as e:
        self.logger.error(f"Error in SQL-based search_schema: {e}")
        return []
1122
+
1123
+ def _extract_queries(self, response_text: str) -> List[str]:
1124
+ """
1125
+ Extract SQL queries from LLM response text.
1126
+
1127
+ Args:
1128
+ response_text: The full response text from the LLM
1129
+
1130
+ Returns:
1131
+ List of extracted SQL queries
1132
+ """
1133
+ queries = []
1134
+
1135
+ # Method 1: Extract from markdown code blocks
1136
+ sql_pattern = r'```sql\n(.*?)\n```'
1137
+ matches = re.findall(sql_pattern, response_text, re.DOTALL | re.IGNORECASE)
1138
+
1139
+ for match in matches:
1140
+ cleaned_query = match.strip()
1141
+ if cleaned_query and not cleaned_query.lower().startswith('--'):
1142
+ queries.append(cleaned_query)
1143
+
1144
+ # Method 2: If no markdown blocks, look for SQL-like patterns
1145
+ # CAUTION: This fallback generates false positives for explanations.
1146
+ # We will disable aggressive line scanning and only support markdown blocks or single-line exact queries.
1147
+ if not queries:
1148
+ cleaned_text = response_text.strip()
1149
+ # If the whole text looks like a query (starts with keyword, ends with ;)
1150
+ if re.match(r'^(SELECT|WITH|SHOW|DESCRIBE|EXPLAIN)\b.*?;$', cleaned_text, re.IGNORECASE | re.DOTALL):
1151
+ queries.append(cleaned_text)
1152
+
1153
+ # Clean up queries
1154
+ cleaned_queries = []
1155
+ for query in queries:
1156
+ # Remove common prefixes/suffixes
1157
+ query = re.sub(r'^```sql\s*', '', query, flags=re.IGNORECASE)
1158
+ query = re.sub(r'\s*```$', '', query)
1159
+ query = query.strip()
1160
+
1161
+ # Basic validation - should contain SELECT, WITH, etc.
1162
+ if re.search(r'\b(SELECT|WITH|SHOW|DESCRIBE|EXPLAIN)\b', query, re.IGNORECASE):
1163
+ cleaned_queries.append(query)
1164
+
1165
+ return cleaned_queries
1166
+
1167
+ def _extract_tables_from_query(self, query: str) -> List[str]:
1168
+ """Extract table names from SQL query."""
1169
+ pattern = r'(?:FROM|JOIN)\s+(?:[\w\.]*\.)?(\w+)'
1170
+ matches = re.findall(pattern, query.upper())
1171
+ return list(set(matches))
1172
+
1173
async def cleanup(self) -> None:
    """
    Cleanup resources.

    Disposes the SQLAlchemy engine (if one was created) and then runs the
    parent cleanup. The parent cleanup runs even when disposing the engine
    raises, so a failing pool shutdown cannot skip it.
    """
    try:
        if self.engine:
            await self.engine.dispose()
    finally:
        await super().cleanup()
1178
+
1179
+
1180
+
1181
+
1182
+ # Factory function for creating enhanced SQL agents
1183
def create_sql_agent(
    database_flavor: str,
    credentials: Union[str, Dict[str, Any]],
    schema_name: str = None,
    **kwargs
) -> SQLAgent:
    """
    Build a SQLAgent for the requested database flavor.

    When no schema name is given, a default is chosen from the flavor
    (compared case-insensitively): 'public' for 'postgresql'/'postgres',
    'mysql' for 'mysql', 'dbo' for 'sqlserver'/'mssql', and 'public' for
    any other flavor.

    Args:
        database_flavor: Database type ('postgresql', 'mysql', 'sqlserver')
        credentials: Connection credentials (string or dict)
        schema_name: Target schema name; None selects the flavor default
        **kwargs: Additional arguments forwarded to SQLAgent

    Returns:
        Configured SQLAgent instance
    """
    if schema_name is None:
        # Flavor -> default schema; unlisted flavors fall back to 'public'.
        default_schemas = {
            'postgresql': 'public',
            'postgres': 'public',
            'mysql': 'mysql',
            'sqlserver': 'dbo',
            'mssql': 'dbo',
        }
        schema_name = default_schemas.get(database_flavor.lower(), 'public')

    return SQLAgent(
        database_flavor=database_flavor,
        credentials=credentials,
        schema_name=schema_name,
        **kwargs
    )
1218
+
1219
+
1220
+ # Example usage
1221
+ """
1222
+ # Dictionary credentials example
1223
+ pg_creds = {
1224
+ 'host': 'localhost',
1225
+ 'port': 5432,
1226
+ 'database': 'sales_db',
1227
+ 'username': 'user',
1228
+ 'password': 'password'
1229
+ }
1230
+
1231
+ pg_agent = create_sql_agent(
1232
+ database_flavor='postgresql',
1233
+ credentials=pg_creds,
1234
+ schema_name='public'
1235
+ )
1236
+
1237
+ # Connection string example
1238
+ mysql_agent = create_sql_agent(
1239
+ database_flavor='mysql',
1240
+ credentials='mysql://user:pass@localhost/dbname'
1241
+ )
1242
+
1243
+ # Usage
1244
+ await pg_agent.initialize_schema()
1245
+
1246
+ # Generate and execute query
1247
+ query_result = await pg_agent.generate_query(
1248
+ "Show me all customers from the East region with their order totals"
1249
+ )
1250
+
1251
+ execution_result = await pg_agent.execute_query(query_result['query'])
1252
+ print(f"Query: {execution_result['query']}")
1253
+ print(f"Data: {execution_result['data']}")
1254
+ """
1255
+