ai-parrot 0.17.2__cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (535)
  1. agentui/.prettierrc +15 -0
  2. agentui/QUICKSTART.md +272 -0
  3. agentui/README.md +59 -0
  4. agentui/env.example +16 -0
  5. agentui/jsconfig.json +14 -0
  6. agentui/package-lock.json +4242 -0
  7. agentui/package.json +34 -0
  8. agentui/scripts/postinstall/apply-patches.mjs +260 -0
  9. agentui/src/app.css +61 -0
  10. agentui/src/app.d.ts +13 -0
  11. agentui/src/app.html +12 -0
  12. agentui/src/components/LoadingSpinner.svelte +64 -0
  13. agentui/src/components/ThemeSwitcher.svelte +159 -0
  14. agentui/src/components/index.js +4 -0
  15. agentui/src/lib/api/bots.ts +60 -0
  16. agentui/src/lib/api/chat.ts +22 -0
  17. agentui/src/lib/api/http.ts +25 -0
  18. agentui/src/lib/components/BotCard.svelte +33 -0
  19. agentui/src/lib/components/ChatBubble.svelte +63 -0
  20. agentui/src/lib/components/Toast.svelte +21 -0
  21. agentui/src/lib/config.ts +20 -0
  22. agentui/src/lib/stores/auth.svelte.ts +73 -0
  23. agentui/src/lib/stores/theme.svelte.js +64 -0
  24. agentui/src/lib/stores/toast.svelte.ts +31 -0
  25. agentui/src/lib/utils/conversation.ts +39 -0
  26. agentui/src/routes/+layout.svelte +20 -0
  27. agentui/src/routes/+page.svelte +232 -0
  28. agentui/src/routes/login/+page.svelte +200 -0
  29. agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
  30. agentui/src/routes/talk/[agentId]/+page.ts +7 -0
  31. agentui/static/README.md +1 -0
  32. agentui/svelte.config.js +11 -0
  33. agentui/tailwind.config.ts +53 -0
  34. agentui/tsconfig.json +3 -0
  35. agentui/vite.config.ts +10 -0
  36. ai_parrot-0.17.2.dist-info/METADATA +472 -0
  37. ai_parrot-0.17.2.dist-info/RECORD +535 -0
  38. ai_parrot-0.17.2.dist-info/WHEEL +6 -0
  39. ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
  40. ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
  41. ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
  42. crew-builder/.prettierrc +15 -0
  43. crew-builder/QUICKSTART.md +259 -0
  44. crew-builder/README.md +113 -0
  45. crew-builder/env.example +17 -0
  46. crew-builder/jsconfig.json +14 -0
  47. crew-builder/package-lock.json +4182 -0
  48. crew-builder/package.json +37 -0
  49. crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
  50. crew-builder/src/app.css +62 -0
  51. crew-builder/src/app.d.ts +13 -0
  52. crew-builder/src/app.html +12 -0
  53. crew-builder/src/components/LoadingSpinner.svelte +64 -0
  54. crew-builder/src/components/ThemeSwitcher.svelte +149 -0
  55. crew-builder/src/components/index.js +9 -0
  56. crew-builder/src/lib/api/bots.ts +60 -0
  57. crew-builder/src/lib/api/chat.ts +80 -0
  58. crew-builder/src/lib/api/client.ts +56 -0
  59. crew-builder/src/lib/api/crew/crew.ts +136 -0
  60. crew-builder/src/lib/api/index.ts +5 -0
  61. crew-builder/src/lib/api/o365/auth.ts +65 -0
  62. crew-builder/src/lib/auth/auth.ts +54 -0
  63. crew-builder/src/lib/components/AgentNode.svelte +43 -0
  64. crew-builder/src/lib/components/BotCard.svelte +33 -0
  65. crew-builder/src/lib/components/ChatBubble.svelte +67 -0
  66. crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
  67. crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
  68. crew-builder/src/lib/components/JsonViewer.svelte +24 -0
  69. crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
  70. crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
  71. crew-builder/src/lib/components/Toast.svelte +67 -0
  72. crew-builder/src/lib/components/Toolbar.svelte +157 -0
  73. crew-builder/src/lib/components/index.ts +10 -0
  74. crew-builder/src/lib/config.ts +8 -0
  75. crew-builder/src/lib/stores/auth.svelte.ts +228 -0
  76. crew-builder/src/lib/stores/crewStore.ts +369 -0
  77. crew-builder/src/lib/stores/theme.svelte.js +145 -0
  78. crew-builder/src/lib/stores/toast.svelte.ts +69 -0
  79. crew-builder/src/lib/utils/conversation.ts +39 -0
  80. crew-builder/src/lib/utils/markdown.ts +122 -0
  81. crew-builder/src/lib/utils/talkHistory.ts +47 -0
  82. crew-builder/src/routes/+layout.svelte +20 -0
  83. crew-builder/src/routes/+page.svelte +539 -0
  84. crew-builder/src/routes/agents/+page.svelte +247 -0
  85. crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
  86. crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
  87. crew-builder/src/routes/builder/+page.svelte +204 -0
  88. crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
  89. crew-builder/src/routes/crew/ask/+page.ts +1 -0
  90. crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
  91. crew-builder/src/routes/login/+page.svelte +197 -0
  92. crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
  93. crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
  94. crew-builder/static/README.md +1 -0
  95. crew-builder/svelte.config.js +11 -0
  96. crew-builder/tailwind.config.ts +53 -0
  97. crew-builder/tsconfig.json +3 -0
  98. crew-builder/vite.config.ts +10 -0
  99. mcp_servers/calculator_server.py +309 -0
  100. parrot/__init__.py +27 -0
  101. parrot/__pycache__/__init__.cpython-310.pyc +0 -0
  102. parrot/__pycache__/version.cpython-310.pyc +0 -0
  103. parrot/_version.py +34 -0
  104. parrot/a2a/__init__.py +48 -0
  105. parrot/a2a/client.py +658 -0
  106. parrot/a2a/discovery.py +89 -0
  107. parrot/a2a/mixin.py +257 -0
  108. parrot/a2a/models.py +376 -0
  109. parrot/a2a/server.py +770 -0
  110. parrot/agents/__init__.py +29 -0
  111. parrot/bots/__init__.py +12 -0
  112. parrot/bots/a2a_agent.py +19 -0
  113. parrot/bots/abstract.py +3139 -0
  114. parrot/bots/agent.py +1129 -0
  115. parrot/bots/basic.py +9 -0
  116. parrot/bots/chatbot.py +669 -0
  117. parrot/bots/data.py +1618 -0
  118. parrot/bots/database/__init__.py +5 -0
  119. parrot/bots/database/abstract.py +3071 -0
  120. parrot/bots/database/cache.py +286 -0
  121. parrot/bots/database/models.py +468 -0
  122. parrot/bots/database/prompts.py +154 -0
  123. parrot/bots/database/retries.py +98 -0
  124. parrot/bots/database/router.py +269 -0
  125. parrot/bots/database/sql.py +41 -0
  126. parrot/bots/db/__init__.py +6 -0
  127. parrot/bots/db/abstract.py +556 -0
  128. parrot/bots/db/bigquery.py +602 -0
  129. parrot/bots/db/cache.py +85 -0
  130. parrot/bots/db/documentdb.py +668 -0
  131. parrot/bots/db/elastic.py +1014 -0
  132. parrot/bots/db/influx.py +898 -0
  133. parrot/bots/db/mock.py +96 -0
  134. parrot/bots/db/multi.py +783 -0
  135. parrot/bots/db/prompts.py +185 -0
  136. parrot/bots/db/sql.py +1255 -0
  137. parrot/bots/db/tools.py +212 -0
  138. parrot/bots/document.py +680 -0
  139. parrot/bots/hrbot.py +15 -0
  140. parrot/bots/kb.py +170 -0
  141. parrot/bots/mcp.py +36 -0
  142. parrot/bots/orchestration/README.md +463 -0
  143. parrot/bots/orchestration/__init__.py +1 -0
  144. parrot/bots/orchestration/agent.py +155 -0
  145. parrot/bots/orchestration/crew.py +3330 -0
  146. parrot/bots/orchestration/fsm.py +1179 -0
  147. parrot/bots/orchestration/hr.py +434 -0
  148. parrot/bots/orchestration/storage/__init__.py +4 -0
  149. parrot/bots/orchestration/storage/memory.py +100 -0
  150. parrot/bots/orchestration/storage/mixin.py +119 -0
  151. parrot/bots/orchestration/verify.py +202 -0
  152. parrot/bots/product.py +204 -0
  153. parrot/bots/prompts/__init__.py +96 -0
  154. parrot/bots/prompts/agents.py +155 -0
  155. parrot/bots/prompts/data.py +216 -0
  156. parrot/bots/prompts/output_generation.py +8 -0
  157. parrot/bots/scraper/__init__.py +3 -0
  158. parrot/bots/scraper/models.py +122 -0
  159. parrot/bots/scraper/scraper.py +1173 -0
  160. parrot/bots/scraper/templates.py +115 -0
  161. parrot/bots/stores/__init__.py +5 -0
  162. parrot/bots/stores/local.py +172 -0
  163. parrot/bots/webdev.py +81 -0
  164. parrot/cli.py +17 -0
  165. parrot/clients/__init__.py +16 -0
  166. parrot/clients/base.py +1491 -0
  167. parrot/clients/claude.py +1191 -0
  168. parrot/clients/factory.py +129 -0
  169. parrot/clients/google.py +4567 -0
  170. parrot/clients/gpt.py +1975 -0
  171. parrot/clients/grok.py +432 -0
  172. parrot/clients/groq.py +986 -0
  173. parrot/clients/hf.py +582 -0
  174. parrot/clients/models.py +18 -0
  175. parrot/conf.py +395 -0
  176. parrot/embeddings/__init__.py +9 -0
  177. parrot/embeddings/base.py +157 -0
  178. parrot/embeddings/google.py +98 -0
  179. parrot/embeddings/huggingface.py +74 -0
  180. parrot/embeddings/openai.py +84 -0
  181. parrot/embeddings/processor.py +88 -0
  182. parrot/exceptions.c +13868 -0
  183. parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
  184. parrot/exceptions.pxd +22 -0
  185. parrot/exceptions.pxi +15 -0
  186. parrot/exceptions.pyx +44 -0
  187. parrot/generators/__init__.py +29 -0
  188. parrot/generators/base.py +200 -0
  189. parrot/generators/html.py +293 -0
  190. parrot/generators/react.py +205 -0
  191. parrot/generators/streamlit.py +203 -0
  192. parrot/generators/template.py +105 -0
  193. parrot/handlers/__init__.py +4 -0
  194. parrot/handlers/agent.py +861 -0
  195. parrot/handlers/agents/__init__.py +1 -0
  196. parrot/handlers/agents/abstract.py +900 -0
  197. parrot/handlers/bots.py +338 -0
  198. parrot/handlers/chat.py +915 -0
  199. parrot/handlers/creation.sql +192 -0
  200. parrot/handlers/crew/ARCHITECTURE.md +362 -0
  201. parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
  202. parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
  203. parrot/handlers/crew/__init__.py +0 -0
  204. parrot/handlers/crew/handler.py +801 -0
  205. parrot/handlers/crew/models.py +229 -0
  206. parrot/handlers/crew/redis_persistence.py +523 -0
  207. parrot/handlers/jobs/__init__.py +10 -0
  208. parrot/handlers/jobs/job.py +384 -0
  209. parrot/handlers/jobs/mixin.py +627 -0
  210. parrot/handlers/jobs/models.py +115 -0
  211. parrot/handlers/jobs/worker.py +31 -0
  212. parrot/handlers/models.py +596 -0
  213. parrot/handlers/o365_auth.py +105 -0
  214. parrot/handlers/stream.py +337 -0
  215. parrot/interfaces/__init__.py +6 -0
  216. parrot/interfaces/aws.py +143 -0
  217. parrot/interfaces/credentials.py +113 -0
  218. parrot/interfaces/database.py +27 -0
  219. parrot/interfaces/google.py +1123 -0
  220. parrot/interfaces/hierarchy.py +1227 -0
  221. parrot/interfaces/http.py +651 -0
  222. parrot/interfaces/images/__init__.py +0 -0
  223. parrot/interfaces/images/plugins/__init__.py +24 -0
  224. parrot/interfaces/images/plugins/abstract.py +58 -0
  225. parrot/interfaces/images/plugins/analisys.py +148 -0
  226. parrot/interfaces/images/plugins/classify.py +150 -0
  227. parrot/interfaces/images/plugins/classifybase.py +182 -0
  228. parrot/interfaces/images/plugins/detect.py +150 -0
  229. parrot/interfaces/images/plugins/exif.py +1103 -0
  230. parrot/interfaces/images/plugins/hash.py +52 -0
  231. parrot/interfaces/images/plugins/vision.py +104 -0
  232. parrot/interfaces/images/plugins/yolo.py +66 -0
  233. parrot/interfaces/images/plugins/zerodetect.py +197 -0
  234. parrot/interfaces/o365.py +978 -0
  235. parrot/interfaces/onedrive.py +822 -0
  236. parrot/interfaces/sharepoint.py +1435 -0
  237. parrot/interfaces/soap.py +257 -0
  238. parrot/loaders/__init__.py +8 -0
  239. parrot/loaders/abstract.py +1131 -0
  240. parrot/loaders/audio.py +199 -0
  241. parrot/loaders/basepdf.py +53 -0
  242. parrot/loaders/basevideo.py +1568 -0
  243. parrot/loaders/csv.py +409 -0
  244. parrot/loaders/docx.py +116 -0
  245. parrot/loaders/epubloader.py +316 -0
  246. parrot/loaders/excel.py +199 -0
  247. parrot/loaders/factory.py +55 -0
  248. parrot/loaders/files/__init__.py +0 -0
  249. parrot/loaders/files/abstract.py +39 -0
  250. parrot/loaders/files/html.py +26 -0
  251. parrot/loaders/files/text.py +63 -0
  252. parrot/loaders/html.py +152 -0
  253. parrot/loaders/markdown.py +442 -0
  254. parrot/loaders/pdf.py +373 -0
  255. parrot/loaders/pdfmark.py +320 -0
  256. parrot/loaders/pdftables.py +506 -0
  257. parrot/loaders/ppt.py +476 -0
  258. parrot/loaders/qa.py +63 -0
  259. parrot/loaders/splitters/__init__.py +10 -0
  260. parrot/loaders/splitters/base.py +138 -0
  261. parrot/loaders/splitters/md.py +228 -0
  262. parrot/loaders/splitters/token.py +143 -0
  263. parrot/loaders/txt.py +26 -0
  264. parrot/loaders/video.py +89 -0
  265. parrot/loaders/videolocal.py +218 -0
  266. parrot/loaders/videounderstanding.py +377 -0
  267. parrot/loaders/vimeo.py +167 -0
  268. parrot/loaders/web.py +599 -0
  269. parrot/loaders/youtube.py +504 -0
  270. parrot/manager/__init__.py +5 -0
  271. parrot/manager/manager.py +1030 -0
  272. parrot/mcp/__init__.py +28 -0
  273. parrot/mcp/adapter.py +105 -0
  274. parrot/mcp/cli.py +174 -0
  275. parrot/mcp/client.py +119 -0
  276. parrot/mcp/config.py +75 -0
  277. parrot/mcp/integration.py +842 -0
  278. parrot/mcp/oauth.py +933 -0
  279. parrot/mcp/server.py +225 -0
  280. parrot/mcp/transports/__init__.py +3 -0
  281. parrot/mcp/transports/base.py +279 -0
  282. parrot/mcp/transports/grpc_session.py +163 -0
  283. parrot/mcp/transports/http.py +312 -0
  284. parrot/mcp/transports/mcp.proto +108 -0
  285. parrot/mcp/transports/quic.py +1082 -0
  286. parrot/mcp/transports/sse.py +330 -0
  287. parrot/mcp/transports/stdio.py +309 -0
  288. parrot/mcp/transports/unix.py +395 -0
  289. parrot/mcp/transports/websocket.py +547 -0
  290. parrot/memory/__init__.py +16 -0
  291. parrot/memory/abstract.py +209 -0
  292. parrot/memory/agent.py +32 -0
  293. parrot/memory/cache.py +175 -0
  294. parrot/memory/core.py +555 -0
  295. parrot/memory/file.py +153 -0
  296. parrot/memory/mem.py +131 -0
  297. parrot/memory/redis.py +613 -0
  298. parrot/models/__init__.py +46 -0
  299. parrot/models/basic.py +118 -0
  300. parrot/models/compliance.py +208 -0
  301. parrot/models/crew.py +395 -0
  302. parrot/models/detections.py +654 -0
  303. parrot/models/generation.py +85 -0
  304. parrot/models/google.py +223 -0
  305. parrot/models/groq.py +23 -0
  306. parrot/models/openai.py +30 -0
  307. parrot/models/outputs.py +285 -0
  308. parrot/models/responses.py +938 -0
  309. parrot/notifications/__init__.py +743 -0
  310. parrot/openapi/__init__.py +3 -0
  311. parrot/openapi/components.yaml +641 -0
  312. parrot/openapi/config.py +322 -0
  313. parrot/outputs/__init__.py +32 -0
  314. parrot/outputs/formats/__init__.py +108 -0
  315. parrot/outputs/formats/altair.py +359 -0
  316. parrot/outputs/formats/application.py +122 -0
  317. parrot/outputs/formats/base.py +351 -0
  318. parrot/outputs/formats/bokeh.py +356 -0
  319. parrot/outputs/formats/card.py +424 -0
  320. parrot/outputs/formats/chart.py +436 -0
  321. parrot/outputs/formats/d3.py +255 -0
  322. parrot/outputs/formats/echarts.py +310 -0
  323. parrot/outputs/formats/generators/__init__.py +0 -0
  324. parrot/outputs/formats/generators/abstract.py +61 -0
  325. parrot/outputs/formats/generators/panel.py +145 -0
  326. parrot/outputs/formats/generators/streamlit.py +86 -0
  327. parrot/outputs/formats/generators/terminal.py +63 -0
  328. parrot/outputs/formats/holoviews.py +310 -0
  329. parrot/outputs/formats/html.py +147 -0
  330. parrot/outputs/formats/jinja2.py +46 -0
  331. parrot/outputs/formats/json.py +87 -0
  332. parrot/outputs/formats/map.py +933 -0
  333. parrot/outputs/formats/markdown.py +172 -0
  334. parrot/outputs/formats/matplotlib.py +237 -0
  335. parrot/outputs/formats/mixins/__init__.py +0 -0
  336. parrot/outputs/formats/mixins/emaps.py +855 -0
  337. parrot/outputs/formats/plotly.py +341 -0
  338. parrot/outputs/formats/seaborn.py +310 -0
  339. parrot/outputs/formats/table.py +397 -0
  340. parrot/outputs/formats/template_report.py +138 -0
  341. parrot/outputs/formats/yaml.py +125 -0
  342. parrot/outputs/formatter.py +152 -0
  343. parrot/outputs/templates/__init__.py +95 -0
  344. parrot/pipelines/__init__.py +0 -0
  345. parrot/pipelines/abstract.py +210 -0
  346. parrot/pipelines/detector.py +124 -0
  347. parrot/pipelines/models.py +90 -0
  348. parrot/pipelines/planogram.py +3002 -0
  349. parrot/pipelines/table.sql +97 -0
  350. parrot/plugins/__init__.py +106 -0
  351. parrot/plugins/importer.py +80 -0
  352. parrot/py.typed +0 -0
  353. parrot/registry/__init__.py +18 -0
  354. parrot/registry/registry.py +594 -0
  355. parrot/scheduler/__init__.py +1189 -0
  356. parrot/scheduler/models.py +60 -0
  357. parrot/security/__init__.py +16 -0
  358. parrot/security/prompt_injection.py +268 -0
  359. parrot/security/security_events.sql +25 -0
  360. parrot/services/__init__.py +1 -0
  361. parrot/services/mcp/__init__.py +8 -0
  362. parrot/services/mcp/config.py +13 -0
  363. parrot/services/mcp/server.py +295 -0
  364. parrot/services/o365_remote_auth.py +235 -0
  365. parrot/stores/__init__.py +7 -0
  366. parrot/stores/abstract.py +352 -0
  367. parrot/stores/arango.py +1090 -0
  368. parrot/stores/bigquery.py +1377 -0
  369. parrot/stores/cache.py +106 -0
  370. parrot/stores/empty.py +10 -0
  371. parrot/stores/faiss_store.py +1157 -0
  372. parrot/stores/kb/__init__.py +9 -0
  373. parrot/stores/kb/abstract.py +68 -0
  374. parrot/stores/kb/cache.py +165 -0
  375. parrot/stores/kb/doc.py +325 -0
  376. parrot/stores/kb/hierarchy.py +346 -0
  377. parrot/stores/kb/local.py +457 -0
  378. parrot/stores/kb/prompt.py +28 -0
  379. parrot/stores/kb/redis.py +659 -0
  380. parrot/stores/kb/store.py +115 -0
  381. parrot/stores/kb/user.py +374 -0
  382. parrot/stores/models.py +59 -0
  383. parrot/stores/pgvector.py +3 -0
  384. parrot/stores/postgres.py +2853 -0
  385. parrot/stores/utils/__init__.py +0 -0
  386. parrot/stores/utils/chunking.py +197 -0
  387. parrot/telemetry/__init__.py +3 -0
  388. parrot/telemetry/mixin.py +111 -0
  389. parrot/template/__init__.py +3 -0
  390. parrot/template/engine.py +259 -0
  391. parrot/tools/__init__.py +23 -0
  392. parrot/tools/abstract.py +644 -0
  393. parrot/tools/agent.py +363 -0
  394. parrot/tools/arangodbsearch.py +537 -0
  395. parrot/tools/arxiv_tool.py +188 -0
  396. parrot/tools/calculator/__init__.py +3 -0
  397. parrot/tools/calculator/operations/__init__.py +38 -0
  398. parrot/tools/calculator/operations/calculus.py +80 -0
  399. parrot/tools/calculator/operations/statistics.py +76 -0
  400. parrot/tools/calculator/tool.py +150 -0
  401. parrot/tools/cloudwatch.py +988 -0
  402. parrot/tools/codeinterpreter/__init__.py +127 -0
  403. parrot/tools/codeinterpreter/executor.py +371 -0
  404. parrot/tools/codeinterpreter/internals.py +473 -0
  405. parrot/tools/codeinterpreter/models.py +643 -0
  406. parrot/tools/codeinterpreter/prompts.py +224 -0
  407. parrot/tools/codeinterpreter/tool.py +664 -0
  408. parrot/tools/company_info/__init__.py +6 -0
  409. parrot/tools/company_info/tool.py +1138 -0
  410. parrot/tools/correlationanalysis.py +437 -0
  411. parrot/tools/database/abstract.py +286 -0
  412. parrot/tools/database/bq.py +115 -0
  413. parrot/tools/database/cache.py +284 -0
  414. parrot/tools/database/models.py +95 -0
  415. parrot/tools/database/pg.py +343 -0
  416. parrot/tools/databasequery.py +1159 -0
  417. parrot/tools/db.py +1800 -0
  418. parrot/tools/ddgo.py +370 -0
  419. parrot/tools/decorators.py +271 -0
  420. parrot/tools/dftohtml.py +282 -0
  421. parrot/tools/document.py +549 -0
  422. parrot/tools/ecs.py +819 -0
  423. parrot/tools/edareport.py +368 -0
  424. parrot/tools/elasticsearch.py +1049 -0
  425. parrot/tools/employees.py +462 -0
  426. parrot/tools/epson/__init__.py +96 -0
  427. parrot/tools/excel.py +683 -0
  428. parrot/tools/file/__init__.py +13 -0
  429. parrot/tools/file/abstract.py +76 -0
  430. parrot/tools/file/gcs.py +378 -0
  431. parrot/tools/file/local.py +284 -0
  432. parrot/tools/file/s3.py +511 -0
  433. parrot/tools/file/tmp.py +309 -0
  434. parrot/tools/file/tool.py +501 -0
  435. parrot/tools/file_reader.py +129 -0
  436. parrot/tools/flowtask/__init__.py +19 -0
  437. parrot/tools/flowtask/tool.py +761 -0
  438. parrot/tools/gittoolkit.py +508 -0
  439. parrot/tools/google/__init__.py +18 -0
  440. parrot/tools/google/base.py +169 -0
  441. parrot/tools/google/tools.py +1251 -0
  442. parrot/tools/googlelocation.py +5 -0
  443. parrot/tools/googleroutes.py +5 -0
  444. parrot/tools/googlesearch.py +5 -0
  445. parrot/tools/googlesitesearch.py +5 -0
  446. parrot/tools/googlevoice.py +2 -0
  447. parrot/tools/gvoice.py +695 -0
  448. parrot/tools/ibisworld/README.md +225 -0
  449. parrot/tools/ibisworld/__init__.py +11 -0
  450. parrot/tools/ibisworld/tool.py +366 -0
  451. parrot/tools/jiratoolkit.py +1718 -0
  452. parrot/tools/manager.py +1098 -0
  453. parrot/tools/math.py +152 -0
  454. parrot/tools/metadata.py +476 -0
  455. parrot/tools/msteams.py +1621 -0
  456. parrot/tools/msword.py +635 -0
  457. parrot/tools/multidb.py +580 -0
  458. parrot/tools/multistoresearch.py +369 -0
  459. parrot/tools/networkninja.py +167 -0
  460. parrot/tools/nextstop/__init__.py +4 -0
  461. parrot/tools/nextstop/base.py +286 -0
  462. parrot/tools/nextstop/employee.py +733 -0
  463. parrot/tools/nextstop/store.py +462 -0
  464. parrot/tools/notification.py +435 -0
  465. parrot/tools/o365/__init__.py +42 -0
  466. parrot/tools/o365/base.py +295 -0
  467. parrot/tools/o365/bundle.py +522 -0
  468. parrot/tools/o365/events.py +554 -0
  469. parrot/tools/o365/mail.py +992 -0
  470. parrot/tools/o365/onedrive.py +497 -0
  471. parrot/tools/o365/sharepoint.py +641 -0
  472. parrot/tools/openapi_toolkit.py +904 -0
  473. parrot/tools/openweather.py +527 -0
  474. parrot/tools/pdfprint.py +1001 -0
  475. parrot/tools/powerbi.py +518 -0
  476. parrot/tools/powerpoint.py +1113 -0
  477. parrot/tools/pricestool.py +146 -0
  478. parrot/tools/products/__init__.py +246 -0
  479. parrot/tools/prophet_tool.py +171 -0
  480. parrot/tools/pythonpandas.py +630 -0
  481. parrot/tools/pythonrepl.py +910 -0
  482. parrot/tools/qsource.py +436 -0
  483. parrot/tools/querytoolkit.py +395 -0
  484. parrot/tools/quickeda.py +827 -0
  485. parrot/tools/resttool.py +553 -0
  486. parrot/tools/retail/__init__.py +0 -0
  487. parrot/tools/retail/bby.py +528 -0
  488. parrot/tools/sandboxtool.py +703 -0
  489. parrot/tools/sassie/__init__.py +352 -0
  490. parrot/tools/scraping/__init__.py +7 -0
  491. parrot/tools/scraping/docs/select.md +466 -0
  492. parrot/tools/scraping/documentation.md +1278 -0
  493. parrot/tools/scraping/driver.py +436 -0
  494. parrot/tools/scraping/models.py +576 -0
  495. parrot/tools/scraping/options.py +85 -0
  496. parrot/tools/scraping/orchestrator.py +517 -0
  497. parrot/tools/scraping/readme.md +740 -0
  498. parrot/tools/scraping/tool.py +3115 -0
  499. parrot/tools/seasonaldetection.py +642 -0
  500. parrot/tools/shell_tool/__init__.py +5 -0
  501. parrot/tools/shell_tool/actions.py +408 -0
  502. parrot/tools/shell_tool/engine.py +155 -0
  503. parrot/tools/shell_tool/models.py +322 -0
  504. parrot/tools/shell_tool/tool.py +442 -0
  505. parrot/tools/site_search.py +214 -0
  506. parrot/tools/textfile.py +418 -0
  507. parrot/tools/think.py +378 -0
  508. parrot/tools/toolkit.py +298 -0
  509. parrot/tools/webapp_tool.py +187 -0
  510. parrot/tools/whatif.py +1279 -0
  511. parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
  512. parrot/tools/workday/__init__.py +6 -0
  513. parrot/tools/workday/models.py +1389 -0
  514. parrot/tools/workday/tool.py +1293 -0
  515. parrot/tools/yfinance_tool.py +306 -0
  516. parrot/tools/zipcode.py +217 -0
  517. parrot/utils/__init__.py +2 -0
  518. parrot/utils/helpers.py +73 -0
  519. parrot/utils/parsers/__init__.py +5 -0
  520. parrot/utils/parsers/toml.c +12078 -0
  521. parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
  522. parrot/utils/parsers/toml.pyx +21 -0
  523. parrot/utils/toml.py +11 -0
  524. parrot/utils/types.cpp +20936 -0
  525. parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
  526. parrot/utils/types.pyx +213 -0
  527. parrot/utils/uv.py +11 -0
  528. parrot/version.py +10 -0
  529. parrot/yaml-rs/Cargo.lock +350 -0
  530. parrot/yaml-rs/Cargo.toml +19 -0
  531. parrot/yaml-rs/pyproject.toml +19 -0
  532. parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
  533. parrot/yaml-rs/src/lib.rs +222 -0
  534. requirements/docker-compose.yml +24 -0
  535. requirements/requirements-dev.txt +21 -0
parrot/clients/google.py
@@ -0,0 +1,4567 @@
+ import re
+ import sys
+ import asyncio
+ from datetime import datetime
+ from typing import Any, AsyncIterator, Dict, List, Optional, Union, Tuple
+ from functools import partial
+ import logging
+ import time
+ from pathlib import Path
+ import contextlib
+ import io
+ import uuid
+ import aiofiles
+ import aiohttp
+ from PIL import Image
+ from google import genai
+ from google.genai.types import (
+     GenerateContentConfig,
+     Part,
+     ModelContent,
+     UserContent,
+ )
+ from google.oauth2 import service_account
+ from google.genai import types
+ from navconfig import config, BASE_DIR
+ import pandas as pd
+ from collections import defaultdict
+ from .base import (
+     AbstractClient,
+     ToolDefinition,
+     RetryConfig,
+     TokenRetryMixin,
+     StreamingRetryConfig
+ )
+ from ..models import (
+     AIMessage,
+     AIMessageFactory,
+     ToolCall,
+     StructuredOutputConfig,
+     OutputFormat,
+     CompletionUsage,
+     ImageGenerationPrompt,
+     SpeakerConfig,
+     SpeechGenerationPrompt,
+     VideoGenerationPrompt,
+     ObjectDetectionResult,
+     GoogleModel,
+     TTSVoice
+ )
+ from ..tools.abstract import AbstractTool, ToolResult
+ from ..models.outputs import (
+     SentimentAnalysis,
+     ProductReview
+ )
+ from ..models.google import (
+     ALL_VOICE_PROFILES,
+     VoiceRegistry,
+     ConversationalScriptConfig,
+     FictionalSpeaker
+ )
+ from ..exceptions import SpeechGenerationError  # pylint: disable=E0611
+ from ..models.detections import (
+     DetectionBox,
+     ShelfRegion,
+     IdentifiedProduct,
+     IdentificationResponse
+ )
+
+ logging.getLogger(
+     name='PIL.TiffImagePlugin'
+ ).setLevel(logging.ERROR)  # Suppress TiffImagePlugin warnings
+ logging.getLogger(
+     name='google_genai'
+ ).setLevel(logging.WARNING)  # Suppress GenAI warnings
+
+
+ class GoogleGenAIClient(AbstractClient):
+     """
+     Client for interacting with Google's Generative AI, with support for parallel function calling.
+
+     Only Gemini-2.5-pro works well with multi-turn function calling.
+     Supports both API Key (Gemini Developer API) and Service Account (Vertex AI).
+     """
+     client_type: str = 'google'
+     client_name: str = 'google'
+     _default_model: str = 'gemini-2.5-flash'
+     _model_garden: bool = False
+
+     def __init__(self, vertexai: bool = False, model_garden: bool = False, **kwargs):
+         self.model_garden = model_garden
+         self.vertexai: bool = True if model_garden else vertexai
+         self.vertex_location = kwargs.get('location', config.get('VERTEX_REGION'))
+         self.vertex_project = kwargs.get('project', config.get('VERTEX_PROJECT_ID'))
+         self._credentials_file = kwargs.get('credentials_file', config.get('VERTEX_CREDENTIALS_FILE'))
+         if isinstance(self._credentials_file, str):
+             self._credentials_file = Path(self._credentials_file).expanduser()
+         self.api_key = kwargs.pop('api_key', config.get('GOOGLE_API_KEY'))
+         super().__init__(**kwargs)
+         self.max_tokens = kwargs.get('max_tokens', 8192)
+         self.client = None
+         # Create a single instance of the Voice registry
+         self.voice_db = VoiceRegistry(profiles=ALL_VOICE_PROFILES)
+
+     async def get_client(self) -> genai.Client:
+         """Get the underlying Google GenAI client."""
+         if self.vertexai:
+             self.logger.info(
+                 f"Initializing Vertex AI for project {self.vertex_project} in {self.vertex_location}"
+             )
+             try:
+                 if self._credentials_file and self._credentials_file.exists():
+                     credentials = service_account.Credentials.from_service_account_file(
+                         str(self._credentials_file)
+                     )
+                 else:
+                     credentials = None  # Use default credentials
+
+                 return genai.Client(
+                     vertexai=True,
+                     project=self.vertex_project,
+                     location=self.vertex_location,
+                     credentials=credentials
+                 )
+             except Exception as exc:
+                 self.logger.error(f"Failed to initialize Vertex AI client: {exc}")
+                 raise
+         return genai.Client(
+             api_key=self.api_key
+         )
+
+     async def close(self):
+         if self.client:
+             with contextlib.suppress(Exception):
+                 await self.client._api_client._aiohttp_session.close()  # pylint: disable=E1101  # noqa
+             self.client = None
+
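For orientation, a minimal usage sketch of the constructor and client factory above. The driver code and key values here are hypothetical, not part of this diff; with an API key the client targets the Gemini Developer API, while vertexai=True plus project/location kwargs switches it to Vertex AI.

import asyncio
from parrot.clients.google import GoogleGenAIClient

async def main():
    # Hypothetical setup: API-key mode (Gemini Developer API).
    client = GoogleGenAIClient(api_key="YOUR_KEY")
    genai_client = await client.get_client()  # underlying google-genai Client
    print(type(genai_client))
    await client.close()

asyncio.run(main())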
+     def _fix_tool_schema(self, schema: dict):
+         """Recursively converts schema type values to uppercase for GenAI compatibility."""
+         if isinstance(schema, dict):
+             for key, value in schema.items():
+                 if key == 'type' and isinstance(value, str):
+                     schema[key] = value.upper()
+                 else:
+                     self._fix_tool_schema(value)
+         elif isinstance(schema, list):
+             for item in schema:
+                 self._fix_tool_schema(item)
+         return schema
+
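An illustration of the transform (the schema and the `client` instance are made up): type strings are uppercased in place, recursively.

schema = {"type": "object", "properties": {"city": {"type": "string"}}}
client._fix_tool_schema(schema)
# schema is now:
# {"type": "OBJECT", "properties": {"city": {"type": "STRING"}}}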
+     def _analyze_prompt_for_tools(self, prompt: str) -> str:
+         """
+         Analyze the prompt to determine which tools might be needed.
+         This is a placeholder for more complex logic that could analyze the prompt.
+         """
+         prompt_lower = prompt.lower()
+         # Keywords that suggest need for built-in tools
+         search_keywords = [
+             'search',
+             'find',
+             'google',
+             'web',
+             'internet',
+             'latest',
+             'news',
+             'weather'
+         ]
+         has_search_intent = any(keyword in prompt_lower for keyword in search_keywords)
+         if has_search_intent:
+             return "builtin_tools"
+         else:
+             # Mixed intent - prefer custom functions if available, otherwise builtin
+             return "custom_functions"
+
+     def _resolve_schema_refs(self, schema: dict, defs: dict = None) -> dict:
+         """
+         Recursively resolves $ref in JSON schema by inlining definitions.
+         This is crucial for Pydantic v2 schemas used with Gemini.
+         """
+         if not isinstance(schema, dict):
+             return schema
+
+         if defs is None:
+             defs = schema.get('$defs', schema.get('definitions', {}))
+
+         # Handle $ref
+         if '$ref' in schema:
+             ref_path = schema['$ref']
+             # Extract definition name (e.g., "#/$defs/MyModel" -> "MyModel")
+             def_name = ref_path.split('/')[-1]
+             if def_name in defs:
+                 # Get the definition
+                 resolved = self._resolve_schema_refs(defs[def_name], defs)
+                 # Merge with any other properties in the current schema (rare but possible)
+                 merged = {k: v for k, v in schema.items() if k != '$ref'}
+                 merged.update(resolved)
+                 return merged
+
+         # Process children
+         new_schema = {}
+         for key, value in schema.items():
+             if key == 'properties' and isinstance(value, dict):
+                 new_schema[key] = {
+                     k: self._resolve_schema_refs(v, defs)
+                     for k, v in value.items()
+                 }
+             elif key == 'items' and isinstance(value, dict):
+                 new_schema[key] = self._resolve_schema_refs(value, defs)
+             elif key in ('anyOf', 'allOf', 'oneOf') and isinstance(value, list):
+                 new_schema[key] = [self._resolve_schema_refs(item, defs) for item in value]
+             else:
+                 new_schema[key] = value
+
+         return new_schema
+
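A sketch of what the resolver does to a Pydantic-v2-style schema (the example schema is hypothetical); the leftover $defs block itself is stripped later by clean_google_schema.

raw = {
    "$defs": {"Address": {"type": "object", "properties": {"city": {"type": "string"}}}},
    "type": "object",
    "properties": {"home": {"$ref": "#/$defs/Address"}},
}
resolved = client._resolve_schema_refs(raw)
# resolved["properties"]["home"] is now the inlined Address definition:
# {"type": "object", "properties": {"city": {"type": "string"}}}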
+     def clean_google_schema(self, schema: dict) -> dict:
+         """
+         Clean a Pydantic-generated schema for Google Function Calling compatibility.
+         Includes reference resolution.
+         """
+         if not isinstance(schema, dict):
+             return schema
+
+         # 1. Resolve references FIRST
+         # Pydantic v2 uses $defs, v1 uses definitions
+         if '$defs' in schema or 'definitions' in schema:
+             schema = self._resolve_schema_refs(schema)
+
+         cleaned = {}
+
+         # Fields that Google Function Calling supports
+         supported_fields = {
+             'type', 'description', 'enum', 'default', 'properties',
+             'required', 'items'
+         }
+
+         # Copy supported fields
+         for key, value in schema.items():
+             if key in supported_fields:
+                 if key == 'properties':
+                     cleaned[key] = {k: self.clean_google_schema(v) for k, v in value.items()}
+                 elif key == 'items':
+                     cleaned[key] = self.clean_google_schema(value)
+                 else:
+                     cleaned[key] = value
+
+         # Type conversion and normalization
+         if 'type' in cleaned:
+             if cleaned['type'] == 'integer':
+                 cleaned['type'] = 'number'  # Google prefers 'number' over 'integer'
+             elif cleaned['type'] == 'object' and 'properties' not in cleaned:
+                 # Ensure objects have a properties field, even if empty, to prevent confusion
+                 cleaned['properties'] = {}
+             elif isinstance(cleaned['type'], list):
+                 non_null_types = [t for t in cleaned['type'] if t != 'null']
+                 cleaned['type'] = non_null_types[0] if non_null_types else 'string'
+
+         # Handle anyOf (union types) - simplified for Gemini
+         if 'anyOf' in schema:
+             for option in schema['anyOf']:
+                 if not isinstance(option, dict):
+                     continue
+                 option_type = option.get('type')
+                 if option_type and option_type != 'null':
+                     cleaned['type'] = option_type
+                     if option_type == 'array' and 'items' in option:
+                         cleaned['items'] = self.clean_google_schema(option['items'])
+                     if option_type == 'object' and 'properties' in option:
+                         cleaned['properties'] = {k: self.clean_google_schema(v) for k, v in option['properties'].items()}
+                     if 'required' in option:
+                         cleaned['required'] = option['required']
+                     break
+             if 'type' not in cleaned:
+                 cleaned['type'] = 'string'
+
+         # Ensure object-like schemas always advertise an object type
+         if 'properties' in cleaned and cleaned.get('type') != 'object':
+             cleaned['type'] = 'object'
+
+         # Remove problematic fields
+         problematic_fields = {
+             'prefixItems', 'additionalItems', 'minItems', 'maxItems',
+             'minLength', 'maxLength', 'pattern', 'format', 'minimum',
+             'maximum', 'exclusiveMinimum', 'exclusiveMaximum', 'multipleOf',
+             'allOf', 'anyOf', 'oneOf', 'not', 'const', 'examples',
+             '$defs', 'definitions', '$ref', 'title', 'additionalProperties'
+         }
+
+         for field in problematic_fields:
+             cleaned.pop(field, None)
+
+         return cleaned
+
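An illustrative before/after (hypothetical input) combining the steps above: refs inlined, integer widened to number, a nullable anyOf collapsed to its non-null type, and unsupported keys dropped.

raw = {
    "title": "Lookup",                                             # dropped (unsupported)
    "type": "object",
    "properties": {
        "count": {"type": "integer"},                              # -> "number"
        "tag": {"anyOf": [{"type": "string"}, {"type": "null"}]},  # -> "string"
    },
    "required": ["count"],
}
cleaned = client.clean_google_schema(raw)
# {"type": "object",
#  "properties": {"count": {"type": "number"}, "tag": {"type": "string"}},
#  "required": ["count"]}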
+     def _recursive_json_repair(self, data: Any) -> Any:
+         """
+         Traverses a dictionary/list and attempts to parse string values
+         that look like JSON objects/lists.
+         """
+         if isinstance(data, dict):
+             return {k: self._recursive_json_repair(v) for k, v in data.items()}
+         elif isinstance(data, list):
+             return [self._recursive_json_repair(item) for item in data]
+         elif isinstance(data, str):
+             data = data.strip()
+             # Fast check if it looks like JSON
+             if (data.startswith('{') and data.endswith('}')) or \
+                (data.startswith('[') and data.endswith(']')):
+                 try:
+                     import json
+                     parsed = json.loads(data)
+                     # Recurse into the parsed object in case it has nested strings
+                     return self._recursive_json_repair(parsed)
+                 except (json.JSONDecodeError, TypeError):
+                     return data
+         return data
+
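Example of the repair pass (made-up payload): a tool that returns JSON embedded in a string gets it parsed back into structure, while plain text is left alone.

payload = {"rows": '[{"sku": 1}, {"sku": 2}]', "note": "plain text"}
repaired = client._recursive_json_repair(payload)
# {"rows": [{"sku": 1}, {"sku": 2}], "note": "plain text"}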
+     def _apply_structured_output_schema(
+         self,
+         generation_config: Dict[str, Any],
+         output_config: Optional[StructuredOutputConfig]
+     ) -> Optional[Dict[str, Any]]:
+         """Apply a cleaned structured output schema to the generation config."""
+         if not output_config or output_config.format != OutputFormat.JSON:
+             return None
+
+         try:
+             raw_schema = output_config.get_schema()
+             cleaned_schema = self.clean_google_schema(raw_schema)
+             fixed_schema = self._fix_tool_schema(cleaned_schema)
+         except Exception as exc:
+             self.logger.error(
+                 f"Failed to generate structured output schema for Gemini: {exc}"
+             )
+             return None
+
+         generation_config["response_mime_type"] = "application/json"
+         generation_config["response_schema"] = fixed_schema
+         return fixed_schema
+
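Effect on the generation config, as a sketch; the output_config here is assumed to request OutputFormat.JSON.

generation_config = {"temperature": 0.2}
schema = client._apply_structured_output_schema(generation_config, output_config)
# generation_config is mutated to:
# {"temperature": 0.2,
#  "response_mime_type": "application/json",
#  "response_schema": <cleaned, uppercased schema>}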
+     def _build_tools(self, tool_type: str, filter_names: Optional[List[str]] = None) -> Optional[List[types.Tool]]:
+         """Build tools based on the specified type."""
+         if tool_type == "custom_functions":
+             # TODO: migrate to use AbstractTool + ToolDefinition
+             # Group function declarations by their category
+             declarations_by_category = defaultdict(list)
+             for tool in self.tool_manager.all_tools():
+                 tool_name = tool.name
+                 if filter_names is not None and tool_name not in filter_names:
+                     continue
+
+                 category = getattr(tool, 'category', 'tools')
+                 if isinstance(tool, AbstractTool):
+                     full_schema = tool.get_tool_schema()
+                     tool_description = full_schema.get("description", tool.description)
+                     # Extract ONLY the parameters part
+                     schema = full_schema.get("parameters", {}).copy()
+                     # Clean the schema for Google compatibility
+                     schema = self.clean_google_schema(schema)
+                 elif isinstance(tool, ToolDefinition):
+                     tool_description = tool.description
+                     schema = self.clean_google_schema(tool.input_schema.copy())
+                 else:
+                     # Fallback for other tool types
+                     tool_description = getattr(tool, 'description', f"Tool: {tool_name}")
+                     schema = getattr(tool, 'input_schema', {})
+                     schema = self.clean_google_schema(schema)
+
+                 # Ensure we have a valid parameters schema
+                 if not schema:
+                     schema = {
+                         "type": "object",
+                         "properties": {},
+                         "required": []
+                     }
+                 try:
+                     declaration = types.FunctionDeclaration(
+                         name=tool_name,
+                         description=tool_description,
+                         parameters=self._fix_tool_schema(schema)
+                     )
+                     declarations_by_category[category].append(declaration)
+                 except Exception as e:
+                     self.logger.error(f"Error creating function declaration for {tool_name}: {e}")
+                     # Skip this tool if it can't be created
+                     continue
+
+             tool_list = []
+             for category, declarations in declarations_by_category.items():
+                 if declarations:
+                     tool_list.append(
+                         types.Tool(
+                             function_declarations=declarations
+                         )
+                     )
+             return tool_list
+         elif tool_type == "builtin_tools":
+             return [
+                 types.Tool(
+                     google_search=types.GoogleSearch()
+                 ),
+             ]
+
+         return None
+
+     def _extract_function_calls(self, response) -> List:
+         """Extract function calls from response - handles both proper function calls AND code blocks."""
+         function_calls = []
+
+         try:
+             if (response.candidates and
+                     len(response.candidates) > 0 and
+                     response.candidates[0].content and
+                     response.candidates[0].content.parts):
+
+                 for part in response.candidates[0].content.parts:
+                     # First, check for proper function calls
+                     if hasattr(part, 'function_call') and part.function_call:
+                         function_calls.append(part.function_call)
+                         self.logger.debug(f"Found proper function call: {part.function_call.name}")
+
+                     # Second, check for text that contains tool code blocks
+                     elif hasattr(part, 'text') and part.text and '```tool_code' in part.text:
+                         self.logger.info("Found tool code block - parsing as function call")
+                         code_block_calls = self._parse_tool_code_blocks(part.text)
+                         function_calls.extend(code_block_calls)
+
+         except (AttributeError, IndexError) as e:
+             self.logger.debug(f"Error extracting function calls: {e}")
+
+         self.logger.debug(f"Total function calls extracted: {len(function_calls)}")
+         return function_calls
+
+     async def _handle_stateless_function_calls(
+         self,
+         response,
+         model: str,
+         contents: List,
+         config,
+         all_tool_calls: List[ToolCall],
+         original_prompt: Optional[str] = None
+     ) -> Any:
+         """Handle function calls in stateless mode (single request-response)."""
+         function_calls = self._extract_function_calls(response)
+
+         if not function_calls:
+             return response
+
+         # Execute function calls
+         tool_call_objects = []
+         for fc in function_calls:
+             tc = ToolCall(
+                 id=f"call_{uuid.uuid4().hex[:8]}",
+                 name=fc.name,
+                 arguments=dict(fc.args)
+             )
+             tool_call_objects.append(tc)
+
+         start_time = time.time()
+         tool_execution_tasks = [
+             self._execute_tool(fc.name, dict(fc.args)) for fc in function_calls
+         ]
+         tool_results = await asyncio.gather(
+             *tool_execution_tasks,
+             return_exceptions=True
+         )
+         execution_time = time.time() - start_time
+
+         for tc, result in zip(tool_call_objects, tool_results):
+             tc.execution_time = execution_time / len(tool_call_objects)
+             if isinstance(result, Exception):
+                 tc.error = str(result)
+             else:
+                 tc.result = result
+
+         all_tool_calls.extend(tool_call_objects)
+
+         # Prepare function responses
+         function_response_parts = []
+         for fc, result in zip(function_calls, tool_results):
+             if isinstance(result, Exception):
+                 response_content = f"Error: {str(result)}"
+             else:
+                 response_content = str(result.get('result', result) if isinstance(result, dict) else result)
+
+             function_response_parts.append(
+                 Part(
+                     function_response=types.FunctionResponse(
+                         name=fc.name,
+                         response={"result": response_content}
+                     )
+                 )
+             )
+
+         if summary_part := self._create_tool_summary_part(
+             function_calls,
+             tool_results,
+             original_prompt
+         ):
+             function_response_parts.append(summary_part)
+
+         # Add function call and responses to conversation
+         contents.append({
+             "role": "model",
+             "parts": [{"function_call": fc} for fc in function_calls]
+         })
+         contents.append({
+             "role": "user",
+             "parts": function_response_parts
+         })
+
+         # Generate final response
+         final_response = await self.client.aio.models.generate_content(
+             model=model,
+             contents=contents,
+             config=config
+         )
+
+         return final_response
+
+     def _process_tool_result_for_api(self, result) -> dict:
+         """
+         Process tool result for Google Function Calling API compatibility.
+         This method serializes various Python objects into a JSON-compatible
+         dictionary for the Google GenAI API.
+         """
+         # 1. Handle exceptions and special wrapper types first
+         if isinstance(result, Exception):
+             return {"result": f"Tool execution failed: {str(result)}", "error": True}
+
+         # Handle ToolResult wrapper
+         if isinstance(result, ToolResult):
+             content = result.result
+             if result.metadata and 'stdout' in result.metadata:
+                 # Prioritize stdout if it exists
+                 content = result.metadata['stdout']
+             result = content  # The actual result to process is the content
+
+         # 2. Handle string results early (no conversion needed)
+         if isinstance(result, str):
+             if not result.strip():
+                 return {"result": "Code executed successfully (no output)"}
+             return {"result": result}
+
+         # 3. Convert complex types to basic Python types
+         clean_result = result
+
+         if isinstance(result, pd.DataFrame):
+             # Convert DataFrame to records and ensure all keys are strings
+             # This handles DataFrames with integer or other non-string column names
+             records = result.to_dict(orient='records')
+             clean_result = [
+                 {str(k): v for k, v in record.items()}
+                 for record in records
+             ]
+         elif isinstance(result, list):
+             # Handle lists (including lists of Pydantic models)
+             clean_result = []
+             for item in result:
+                 if hasattr(item, 'model_dump'):  # Pydantic v2
+                     clean_result.append(item.model_dump())
+                 elif hasattr(item, 'dict'):  # Pydantic v1
+                     clean_result.append(item.dict())
+                 else:
+                     clean_result.append(item)
+         elif hasattr(result, 'model_dump'):  # Pydantic v2 single model
+             clean_result = result.model_dump()
+         elif hasattr(result, 'dict'):  # Pydantic v1 single model
+             clean_result = result.dict()
+
+         # 4. Attempt to serialize the processed result
+         try:
+             serialized = self._json.dumps(clean_result)
+             json_compatible_result = self._json.loads(serialized)
+         except Exception as e:
+             # This is the fallback for non-serializable objects (like PriceOutput)
+             self.logger.warning(
+                 f"Could not serialize result of type {type(clean_result)} to JSON: {e}. "
+                 "Falling back to string representation."
+             )
+             json_compatible_result = str(clean_result)
+
+         # Wrap for Google Function Calling format
+         if isinstance(json_compatible_result, dict) and 'result' in json_compatible_result:
+             return json_compatible_result
+         else:
+             return {"result": json_compatible_result}
+
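The wrapping contract in brief (inputs hypothetical), following the branches above:

client._process_tool_result_for_api("done")
# -> {"result": "done"}
client._process_tool_result_for_api("   ")
# -> {"result": "Code executed successfully (no output)"}
client._process_tool_result_for_api(ValueError("boom"))
# -> {"result": "Tool execution failed: boom", "error": True}
client._process_tool_result_for_api(pd.DataFrame({"a": [1]}))
# -> {"result": [{"a": 1}]}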
+     def _summarize_tool_result(self, result: Any, max_length: int = 1200) -> str:
+         """Create a short, human-readable summary of a tool result."""
+         try:
+             if isinstance(result, Exception):
+                 summary = f"Error: {result}"
+             elif isinstance(result, pd.DataFrame):
+                 preview = result.head(5)
+                 summary = preview.to_string(index=True)
+             elif hasattr(result, 'model_dump'):
+                 summary = self._json.dumps(result.model_dump())
+             elif isinstance(result, (dict, list)):
+                 summary = self._json.dumps(result)
+             else:
+                 summary = str(result)
+         except Exception as exc:  # pylint: disable=broad-except
+             summary = f"Unable to summarize result: {exc}"
+
+         summary = summary.strip() or "[empty result]"
+         if len(summary) > max_length:
+             summary = summary[:max_length].rstrip() + "…"
+         return summary
+
+     def _create_tool_summary_part(
+         self,
+         function_calls,
+         tool_results,
+         original_prompt: Optional[str] = None
+     ) -> Optional[Part]:
+         """Build a textual summary of tool outputs for the model to read easily."""
+         if not function_calls or not tool_results:
+             return None
+
+         summary_lines = ["Tool execution summaries:"]
+         for fc, result in zip(function_calls, tool_results):
+             summary_lines.append(
+                 f"- {fc.name}: {self._summarize_tool_result(result)}"
+             )
+
+         if original_prompt:
+             summary_lines.append(f"Original Request: {original_prompt}")
+
+         summary_lines.append(
+             "Use the information above to craft the final response without running redundant tool calls."
+         )
+
+         summary_text = "\n".join(summary_lines)
+         return Part(text=summary_text)
+
+     async def _handle_multiturn_function_calls(
+         self,
+         chat,
+         initial_response,
+         all_tool_calls: List[ToolCall],
+         original_prompt: Optional[str] = None,
+         model: str = None,
+         max_iterations: int = 10,
+         config: GenerateContentConfig = None,
+         max_retries: int = 1,
+         lazy_loading: bool = False,
+         active_tool_names: Optional[set] = None,
+     ) -> Any:
+         """
+         Simple multi-turn function calling - just keep going until no more function calls.
+         """
+         current_response = initial_response
+         current_config = config
+         iteration = 0
+
+         if active_tool_names is None:
+             active_tool_names = set()
+
+         model = model or self.model
+         self.logger.info("Starting simple multi-turn function calling loop")
+
+         while iteration < max_iterations:
+             iteration += 1
+
+             # Get function calls (including those converted from tool_code)
+             function_calls = self._get_function_calls_from_response(current_response)
+             if not function_calls:
+                 # Check if we have any text content in the response
+                 final_text = self._safe_extract_text(current_response)
+                 self.logger.notice(f"🎯 Final Response from Gemini: {final_text[:200]}...")
+                 if not final_text and all_tool_calls:
+                     self.logger.warning(
+                         "Final response is empty after tool execution, generating summary..."
+                     )
+                     try:
+                         synthesis_prompt = """
+                         Please now generate the complete response based on all the information gathered from the tools.
+                         Provide a comprehensive answer to the original request.
+                         Synthesize the data and provide insights, analysis, and conclusions as appropriate.
+                         """
+                         current_response = await chat.send_message(
+                             synthesis_prompt,
+                             config=current_config
+                         )
+                         # Check if this worked
+                         synthesis_text = self._safe_extract_text(current_response)
+                         if synthesis_text:
+                             self.logger.info("Successfully generated synthesis response")
+                         else:
+                             self.logger.warning("Synthesis attempt also returned empty response")
+                     except Exception as e:
+                         self.logger.error(f"Synthesis attempt failed: {e}")
+
+                 self.logger.info(
+                     f"No function calls found - completed after {iteration-1} iterations"
+                 )
+                 break
+
+             self.logger.info(
+                 f"Iteration {iteration}: Processing {len(function_calls)} function calls"
+             )
+
+             # Execute function calls
+             tool_call_objects = []
+             for fc in function_calls:
+                 tc = ToolCall(
+                     id=f"call_{uuid.uuid4().hex[:8]}",
+                     name=fc.name,
+                     arguments=dict(fc.args) if hasattr(fc.args, 'items') else fc.args
+                 )
+                 tool_call_objects.append(tc)
+
+             # Execute tools
+             start_time = time.time()
+             tool_execution_tasks = [
+                 self._execute_tool(fc.name, dict(fc.args) if hasattr(fc.args, 'items') else fc.args)
+                 for fc in function_calls
+             ]
+             tool_results = await asyncio.gather(*tool_execution_tasks, return_exceptions=True)
+             execution_time = time.time() - start_time
+
+             # Lazy loading check
+             if lazy_loading:
+                 found_new = False
+                 for fc, result in zip(function_calls, tool_results):
+                     if fc.name == "search_tools" and isinstance(result, str):
+                         new_tools = self._check_new_tools(fc.name, result)
+                         for nt in new_tools:
+                             if nt not in active_tool_names:
+                                 active_tool_names.add(nt)
+                                 found_new = True
+
+                 if found_new:
+                     # Rebuild tools with the expanded set
+                     new_tools_list = self._build_tools("custom_functions", filter_names=list(active_tool_names))
+                     current_config.tools = new_tools_list
+                     self.logger.info(f"Updated tools for next turn. Count: {len(active_tool_names)}")
+
+             # Update tool call objects
+             for tc, result in zip(tool_call_objects, tool_results):
+                 tc.execution_time = execution_time / len(tool_call_objects)
+                 if isinstance(result, Exception):
+                     tc.error = str(result)
+                     self.logger.error(f"Tool {tc.name} failed: {result}")
+                 else:
+                     tc.result = result
+                     # self.logger.info(f"Tool {tc.name} result: {result}")
+
+             all_tool_calls.extend(tool_call_objects)
+             function_response_parts = []
+             for fc, result in zip(function_calls, tool_results):
+                 tool_id = fc.id or f"call_{uuid.uuid4().hex[:8]}"
+                 self.logger.notice(f"🔍 Tool: {fc.name}")
+                 self.logger.notice(f"📤 Raw Result Type: {type(result)}")
+
+                 try:
+                     response_content = self._process_tool_result_for_api(result)
+                     self.logger.info(f"📦 Processed for API: {response_content}")
+
+                     function_response_parts.append(
+                         Part(
+                             function_response=types.FunctionResponse(
+                                 id=tool_id,
+                                 name=fc.name,
+                                 response=response_content
+                             )
+                         )
+                     )
+                 except Exception as e:
+                     self.logger.error(f"Error processing result for tool {fc.name}: {e}")
+                     function_response_parts.append(
+                         Part(
+                             function_response=types.FunctionResponse(
+                                 id=tool_id,
+                                 name=fc.name,
+                                 response={"result": f"Tool error: {str(e)}", "error": True}
+                             )
+                         )
+                     )
+
+             summary_part = self._create_tool_summary_part(
+                 function_calls,
+                 tool_results,
+                 original_prompt
+             )
+             # Combine the tool results with the textual summary prompt
+             next_prompt_parts = function_response_parts.copy()
+             if summary_part:
+                 next_prompt_parts.append(summary_part)
+
+             # Send responses back
+             retry_count = 0
+             try:
+                 self.logger.debug(
+                     f"Sending {len(next_prompt_parts)} responses back to model"
+                 )
+                 while retry_count < max_retries:
+                     try:
+                         current_response = await chat.send_message(
+                             next_prompt_parts,
+                             config=current_config
+                         )
+                         finish_reason = getattr(current_response.candidates[0], 'finish_reason', None)
+                         if finish_reason:
+                             if finish_reason.name == "MAX_TOKENS" and current_config.max_output_tokens < 8192:
+                                 self.logger.warning(
+                                     "Hit MAX_TOKENS limit. Retrying with increased token limit."
+                                 )
+                                 retry_count += 1
+                                 current_config.max_output_tokens = 8192
+                                 continue
+                             elif finish_reason.name == "MALFORMED_FUNCTION_CALL":
+                                 self.logger.warning(
+                                     "Malformed function call detected. Retrying..."
+                                 )
+                                 retry_count += 1
+                                 await asyncio.sleep(2 ** retry_count)
+                                 continue
+                         break
+                     except Exception as e:
+                         self.logger.error(f"Error sending message: {e}")
+                         retry_count += 1
+                         await asyncio.sleep(2 ** retry_count)  # Exponential backoff
+                         if (retry_count + 1) >= max_retries:
+                             self.logger.error("Max retries reached, aborting")
+                             raise e
+
+                 # Check for UNEXPECTED_TOOL_CALL error
+                 if (hasattr(current_response, 'candidates') and
+                         current_response.candidates and
+                         hasattr(current_response.candidates[0], 'finish_reason')):
+
+                     finish_reason = current_response.candidates[0].finish_reason
+
+                     if str(finish_reason) == 'FinishReason.UNEXPECTED_TOOL_CALL':
+                         self.logger.warning("Received UNEXPECTED_TOOL_CALL")
+
+                         # Debug what we got back
+                         if hasattr(current_response, 'text'):
+                             try:
+                                 preview = current_response.text[:100] if current_response.text else "No text"
+                                 self.logger.debug(f"Response preview: {preview}")
+                             except Exception:
+                                 self.logger.debug("Could not preview response text")
+
+             except Exception as e:
+                 self.logger.error(f"Failed to send responses back: {e}")
+                 break
+
+         self.logger.info(f"Completed with {len(all_tool_calls)} total tool calls")
+         return current_response
+
+     def _parse_tool_code_blocks(self, text: str) -> List:
+         """Convert tool_code blocks to function call objects."""
+         function_calls = []
+
+         if '```tool_code' not in text:
+             return function_calls
+
+         # Simple regex to extract tool calls
+         pattern = r'```tool_code\s*\n\s*print\(default_api\.(\w+)\((.*?)\)\)\s*\n\s*```'
+         matches = re.findall(pattern, text, re.DOTALL)
+
+         for tool_name, args_str in matches:
+             self.logger.debug(f"Converting tool_code to function call: {tool_name}")
+             try:
+                 # Parse arguments like: a = 9310, b = 3, operation = "divide"
+                 args = {}
+                 for arg_part in args_str.split(','):
+                     if '=' in arg_part:
+                         key, value = arg_part.split('=', 1)
+                         key = key.strip()
+                         value = value.strip().strip('"\'')  # Remove quotes
+
+                         # Try to convert to a number
+                         try:
+                             if '.' in value:
+                                 args[key] = float(value)
+                             else:
+                                 args[key] = int(value)
+                         except ValueError:
+                             args[key] = value  # Keep as string
+                 # Extract the tool from the Tool Manager
+                 tool = self.tool_manager.get_tool(tool_name)
+                 if tool:
+                     # Create function call
+                     fc = types.FunctionCall(
+                         id=f"call_{uuid.uuid4().hex[:8]}",
+                         name=tool_name,
+                         args=args
+                     )
+                     function_calls.append(fc)
+                     self.logger.info(f"Created function call: {tool_name}({args})")
+
+             except Exception as e:
+                 self.logger.error(f"Failed to parse tool_code: {e}")
+
+         return function_calls
+
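An example of the fallback in action, using the argument style from the comment above; the calculator tool is hypothetical and must be registered in the tool manager for a FunctionCall to be emitted.

text = '''```tool_code
print(default_api.calculator(a = 9310, b = 3, operation = "divide"))
```'''
calls = client._parse_tool_code_blocks(text)
# one types.FunctionCall: name="calculator",
# args={"a": 9310, "b": 3, "operation": "divide"}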
+ def _get_function_calls_from_response(self, response) -> List:
905
+ """Get function calls from response - handles both proper calls and tool_code blocks."""
906
+ function_calls = []
907
+
908
+ try:
909
+ if (response.candidates and
910
+ response.candidates[0].content and
911
+ response.candidates[0].content.parts):
912
+
913
+ for part in response.candidates[0].content.parts:
914
+ # Check for proper function calls first
915
+ if hasattr(part, 'function_call') and part.function_call:
916
+ function_calls.append(part.function_call)
917
+ self.logger.debug(
918
+ f"Found proper function call: {part.function_call.name}"
919
+ )
920
+
921
+ # Check for tool_code in text parts
922
+ elif hasattr(part, 'text') and part.text and '```tool_code' in part.text:
923
+ self.logger.info("Found tool_code block - converting to function call")
924
+ code_function_calls = self._parse_tool_code_blocks(part.text)
925
+ function_calls.extend(code_function_calls)
926
+
927
+ except Exception as e:
928
+ self.logger.error(f"Error getting function calls: {e}")
929
+
930
+ self.logger.info(f"Total function calls found: {len(function_calls)}")
931
+ return function_calls
932
+
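Because `_get_function_calls_from_response` relies only on attribute access, the traversal can be exercised with duck-typed stand-ins for the SDK response. A minimal sketch (all objects here are fakes built with `SimpleNamespace`, not real google-genai types):

```python
from types import SimpleNamespace

# Fake response shaped like a google-genai response with one function-call part.
part = SimpleNamespace(
    function_call=SimpleNamespace(name="get_weather", args={"city": "Paris"}),
    text=None,
)
response = SimpleNamespace(
    candidates=[SimpleNamespace(content=SimpleNamespace(parts=[part]))]
)

calls = []
for p in response.candidates[0].content.parts:
    if getattr(p, "function_call", None):
        calls.append(p.function_call)

print([c.name for c in calls])  # -> ['get_weather']
```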
933
+ def _safe_extract_text(self, response) -> str:
934
+ """
935
+ Enhanced text extraction that handles reasoning models and mixed content warnings.
936
+
937
+ This method tries multiple approaches to extract text from Google GenAI responses,
938
+ handling special cases like thought_signature parts from reasoning models.
939
+ """
940
+
941
+ # Pre-check for function calls to avoid library warnings when accessing .text
942
+ has_function_call = False
943
+ try:
944
+ if (hasattr(response, 'candidates') and response.candidates and
945
+ len(response.candidates) > 0 and hasattr(response.candidates[0], 'content') and
946
+ response.candidates[0].content and hasattr(response.candidates[0].content, 'parts') and
947
+ response.candidates[0].content.parts):
948
+ for part in response.candidates[0].content.parts:
949
+ if hasattr(part, 'function_call') and part.function_call:
950
+ has_function_call = True
951
+ break
952
+ except Exception:
953
+ pass
954
+
955
+ # Method 1: Try response.text first (fastest path)
956
+ # Skip if we found a function call, as accessing .text triggers a warning in the library
957
+ if not has_function_call:
958
+ try:
959
+ if hasattr(response, 'text') and response.text:
960
+ if (text := response.text.strip()):
961
+ self.logger.debug(
962
+ f"Extracted text via response.text: '{text[:100]}...'"
963
+ )
964
+ return text
965
+ except Exception as e:
966
+ # This is expected with reasoning models that have mixed content
967
+ self.logger.debug(
968
+ f"response.text failed (normal for reasoning models): {e}"
969
+ )
970
+
971
+ # Method 2: Manual extraction from parts (more robust)
972
+ try:
973
+ if (hasattr(response, 'candidates') and response.candidates and len(response.candidates) > 0 and
974
+ hasattr(response.candidates[0], 'content') and response.candidates[0].content and
975
+ hasattr(response.candidates[0].content, 'parts') and response.candidates[0].content.parts):
976
+
977
+ text_parts = []
978
+ thought_parts_found = 0
979
+
980
+ # Extract text from each part, handling special cases
981
+ for part in response.candidates[0].content.parts:
982
+ # Check for regular text content
983
+ if hasattr(part, 'text') and part.text:
984
+ if (clean_text := part.text.strip()):
985
+ text_parts.append(clean_text)
986
+ self.logger.debug(
987
+ f"Found text part: '{clean_text[:50]}...'"
988
+ )
989
+
990
+ # Log non-text parts but don't extract them
991
+ elif hasattr(part, 'thought_signature'):
992
+ thought_parts_found += 1
993
+ self.logger.debug(
994
+ "Found thought_signature part (reasoning model internal thought)"
995
+ )
996
+
997
+ # Log reasoning model detection
998
+ if thought_parts_found > 0:
999
+ self.logger.debug(
1000
+ f"Detected reasoning model with {thought_parts_found} thought parts"
1001
+ )
1002
+
1003
+ # Combine text parts
1004
+ if text_parts:
1005
+ if (combined_text := "".join(text_parts).strip()):
1006
+ self.logger.debug(
1007
+ f"Successfully extracted text from {len(text_parts)} parts"
1008
+ )
1009
+ return combined_text
1010
+ else:
1011
+ self.logger.debug("No text parts found in response parts")
1012
+
1013
+ except Exception as e:
1014
+ self.logger.error(f"Manual text extraction failed: {e}")
1015
+
1016
+ # Method 3: Deep inspection for debugging (fallback)
1017
+ try:
1018
+ if hasattr(response, 'candidates') and response.candidates:
1019
+ candidate = response.candidates[0] if len(response.candidates) > 0 else None
1020
+ if candidate:
1021
+ if hasattr(candidate, 'finish_reason'):
1022
+ finish_reason = str(candidate.finish_reason)
1023
+ self.logger.debug(f"Response finish reason: {finish_reason}")
1024
+ if 'MAX_TOKENS' in finish_reason:
1025
+ self.logger.warning("Response truncated due to token limit")
1026
+ elif 'SAFETY' in finish_reason:
1027
+ self.logger.warning("Response blocked by safety filters")
1028
+ elif 'STOP' in finish_reason:
1029
+ self.logger.debug("Response completed normally but no text found")
1030
+
1031
+ if hasattr(candidate, 'content') and candidate.content:
1032
+ if hasattr(candidate.content, 'parts'):
1033
+ parts_count = len(candidate.content.parts) if candidate.content.parts else 0
1034
+ self.logger.debug(f"Response has {parts_count} parts but no extractable text")
1035
+ if candidate.content.parts:
1036
+ part_types = []
1037
+ for part in candidate.content.parts:
1038
+ part_attrs = [attr for attr in dir(part)
1039
+ if not attr.startswith('_') and hasattr(part, attr) and getattr(part, attr)]
1040
+ part_types.append(part_attrs)
1041
+ self.logger.debug(f"Part attribute types found: {part_types}")
1042
+
1043
+ except Exception as e:
1044
+ self.logger.error(f"Deep inspection failed: {e}")
1045
+
1046
+ # Method 4: Final fallback - return empty string with clear logging
1047
+ self.logger.warning(
1048
+ "Could not extract any text from response using any method"
1049
+ )
1050
+ return ""
1051
+
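A condensed sketch of the extraction order `_safe_extract_text` uses, against a fake reasoning-model response whose `.text` accessor raises but whose parts still carry text (the `FakeResponse` class is illustrative, not an SDK object):

```python
from types import SimpleNamespace

class FakeResponse:
    """Mimics a reasoning-model response: .text raises, parts hold the text."""
    candidates = [SimpleNamespace(content=SimpleNamespace(parts=[
        SimpleNamespace(text="Final answer."),
        SimpleNamespace(text=None, thought_signature=b"..."),
    ]))]

    @property
    def text(self) -> str:
        raise ValueError("mixed content")

def safe_extract_text(response) -> str:
    try:  # Method 1: the fast path
        if response.text:
            return response.text.strip()
    except Exception:
        pass  # expected for mixed-content responses
    # Method 2: manual extraction from parts
    parts = response.candidates[0].content.parts
    return "".join(p.text for p in parts if getattr(p, "text", None)).strip()

print(safe_extract_text(FakeResponse()))  # -> "Final answer."
```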
1052
+ async def ask(
1053
+ self,
1054
+ prompt: str,
1055
+ model: Union[str, GoogleModel] = None,
1056
+ max_tokens: Optional[int] = None,
1057
+ temperature: Optional[float] = None,
1058
+ files: Optional[List[Union[str, Path]]] = None,
1059
+ system_prompt: Optional[str] = None,
1060
+ structured_output: Union[type, StructuredOutputConfig] = None,
1061
+ user_id: Optional[str] = None,
1062
+ session_id: Optional[str] = None,
1063
+ tools: Optional[List[Dict[str, Any]]] = None,
1064
+ use_tools: Optional[bool] = None,
1065
+ stateless: bool = False,
1066
+ deep_research: bool = False,
1067
+ background: bool = False,
1068
+ file_search_store_names: Optional[List[str]] = None,
1069
+ lazy_loading: bool = False,
1070
+ **kwargs
1071
+ ) -> AIMessage:
1072
+ """
1073
+ Ask a question to Google's Generative AI with support for parallel tool calls.
1074
+
1075
+ Args:
1076
+ prompt (str): The input prompt for the model.
1077
+ model (Union[str, GoogleModel]): The model to use. If None, uses the client's configured model
1078
+ or defaults to GEMINI_2_5_FLASH.
1079
+ max_tokens (int): Maximum number of tokens in the response.
1080
+ temperature (float): Sampling temperature for response generation.
1081
+ files (Optional[List[Union[str, Path]]]): Optional files to include in the request.
1082
+ system_prompt (Optional[str]): Optional system prompt to guide the model.
1083
+ structured_output (Union[type, StructuredOutputConfig]): Optional structured output configuration.
1084
+ user_id (Optional[str]): Optional user identifier for tracking.
+ session_id (Optional[str]): Optional session identifier for tracking.
+ tools (Optional[List[Dict[str, Any]]]): Extra tools to register for this request.
+ use_tools (Optional[bool]): Enable or disable tool usage; if None, falls back to
+ the client default (self.enable_tools).
+ stateless (bool): If True, don't use conversation memory (stateless mode).
+ deep_research (bool): If True, use Google's deep research agent.
+ background (bool): If True, execute deep research in background mode.
+ file_search_store_names (Optional[List[str]]): Names of file search stores for deep research.
+ lazy_loading (bool): If True, expose only a 'search_tools' function initially and
+ let the model discover additional tools on demand.
+
+ Returns:
+ AIMessage: The model's response, including any tool calls and structured output.
1092
+ """
1093
+ max_retries = kwargs.pop('max_retries', 1)
1094
+
1095
+ # Route to deep research if requested
1096
+ if deep_research:
1097
+ self.logger.info("Using Google Deep Research mode via interactions.create()")
1098
+ return await self._deep_research_ask(
1099
+ prompt=prompt,
1100
+ background=background,
1101
+ file_search_store_names=file_search_store_names,
1102
+ user_id=user_id,
1103
+ session_id=session_id
1104
+ )
1105
+
1106
+ model = model.value if isinstance(model, GoogleModel) else model
1107
+ # If use_tools is None, use the instance default
1108
+ _use_tools = use_tools if use_tools is not None else self.enable_tools
1109
+ if not model:
1110
+ model = self.model or GoogleModel.GEMINI_2_5_FLASH.value
1111
+
1112
+ # Handle case where model is passed as a tuple or list
1113
+ if isinstance(model, (list, tuple)):
1114
+ model = model[0]
1115
+
1116
+ # Generate unique turn ID for tracking
1117
+ turn_id = str(uuid.uuid4())
1118
+ original_prompt = prompt
1119
+
1120
+ # Prepare conversation context using unified memory system
1121
+ conversation_history = None
1122
+ messages = []
1123
+
1124
+ # Use the abstract method to prepare conversation context
1125
+ if stateless:
1126
+ # For stateless mode, skip conversation memory
1127
+ messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
1128
+ conversation_history = None
1129
+ else:
1130
+ # Use the unified conversation context preparation from AbstractClient
1131
+ messages, conversation_history, system_prompt = await self._prepare_conversation_context(
1132
+ prompt, files, user_id, session_id, system_prompt, stateless=stateless
1133
+ )
1134
+
1135
+ # Prepare conversation history for Google GenAI format
1136
+ history = []
1137
+ # Construct history directly from the 'messages' array, which should be in the correct format
1138
+ if messages:
1139
+ for msg in messages[:-1]: # Exclude the current user message (last in list)
1140
+ role = msg['role'].lower()
1141
+ # Assuming content is already in the format [{"type": "text", "text": "..."}]
1142
+ # or other GenAI Part types if files were involved.
1143
+ # Here, we only expect text content for history, as images/files are for the current turn.
1144
+ if role == 'user':
1145
+ # Content can be a list of dicts (for text/parts) or a single string.
1146
+ # Standardize to list of Parts.
1147
+ parts = []
1148
+ for part_content in msg.get('content', []):
1149
+ if isinstance(part_content, dict) and part_content.get('type') == 'text':
1150
+ parts.append(Part(text=part_content.get('text', '')))
1151
+ # Add other part types if necessary for history (e.g., function responses)
1152
+ if parts:
1153
+ history.append(UserContent(parts=parts))
1154
+ elif role in ['assistant', 'model']:
1155
+ parts = []
1156
+ for part_content in msg.get('content', []):
1157
+ if isinstance(part_content, dict) and part_content.get('type') == 'text':
1158
+ parts.append(Part(text=part_content.get('text', '')))
1159
+ if parts:
1160
+ history.append(ModelContent(parts=parts))
1161
+
1162
+ default_tokens = max_tokens or self.max_tokens or 4096
1163
+ generation_config = {
1164
+ "max_output_tokens": default_tokens,
1165
+ "temperature": temperature or self.temperature
1166
+ }
1167
+ base_temperature = generation_config["temperature"]
1168
+
1169
+ # Prepare structured output configuration
1170
+ output_config = self._get_structured_config(structured_output)
1171
+
1172
+ # Tool selection
1173
+ requested_tools = tools
1174
+
1175
+ if _use_tools:
1176
+ if requested_tools and isinstance(requested_tools, list):
1177
+ for tool in requested_tools:
1178
+ self.register_tool(tool)
1179
+ tool_type = "custom_functions"
+ # When tools are in play, drop the temperature to reduce hallucinated calls.
+ generation_config["temperature"] = 0
1182
+ elif _use_tools is None:
1183
+ # If not explicitly set, analyze the prompt to decide
1184
+ tool_type = self._analyze_prompt_for_tools(prompt)
1185
+ else:
1186
+ tool_type = 'builtin_tools' if _use_tools else None
1187
+
1188
+ tools = self._build_tools(tool_type) if tool_type else []
1189
+
1190
+ if _use_tools and tool_type == "custom_functions" and not tools:
1191
+ self.logger.info(
1192
+ "Tool usage requested but no tools are registered - disabling tools for this request."
1193
+ )
1194
+ _use_tools = False
1195
+ tool_type = None
1196
+ tools = []
1197
+ generation_config["temperature"] = base_temperature
1198
+
1199
+ use_tools = _use_tools
1200
+
1201
+ # LAZY LOADING LOGIC
1202
+ active_tool_names = set()
1203
+ if use_tools and lazy_loading:
1204
+ # Override initial tool selection to just search_tools
1205
+ active_tool_names.add("search_tools")
1206
+ tools = self._build_tools("custom_functions", filter_names=["search_tools"])
1207
+ # Add system prompt instruction
1208
+ search_prompt = "You have access to a library of tools. Use the 'search_tools' function to find relevant tools."
1209
+ system_prompt = f"{system_prompt}\n\n{search_prompt}" if system_prompt else search_prompt
1210
+ # Update final_config later with this new system prompt if needed,
1211
+ # but system_prompt is passed to GenerateContentConfig below.
1212
+
1214
+ self.logger.debug(
1215
+ f"Using model: {model}, max_tokens: {default_tokens}, temperature: {temperature}, "
1216
+ f"structured_output: {structured_output}, "
1217
+ f"use_tools: {_use_tools}, tool_type: {tool_type}, toolbox: {len(tools)}, "
1218
+ )
1219
+
1220
+ use_structured_output = bool(output_config)
1221
+ # Google limitation: Cannot combine tools with structured output
1222
+ # Strategy: If both are requested, use tools first, then apply structured output to final result
1223
+ if _use_tools and use_structured_output:
1224
+ self.logger.info(
1225
+ "Google Gemini doesn't support tools + structured output simultaneously. "
1226
+ "Using tools first, then applying structured output to the final result."
1227
+ )
1228
+ structured_output_for_later = output_config
1229
+ # Don't set structured output in initial config
1230
+ output_config = None
1231
+ else:
1232
+ structured_output_for_later = None
1233
+ # Set structured output in generation config if no tools conflict
1234
+ if output_config:
1235
+ self._apply_structured_output_schema(generation_config, output_config)
1236
+
1237
+ # Track tool calls for the response
1238
+ all_tool_calls = []
1239
+ # Build contents for conversation
1240
+ contents = []
1241
+
1242
+ for msg in messages:
1243
+ role = "model" if msg["role"] == "assistant" else msg["role"]
1244
+ if role in ["user", "model"]:
1245
+ text_parts = [part["text"] for part in msg["content"] if "text" in part]
1246
+ if text_parts:
1247
+ contents.append({
1248
+ "role": role,
1249
+ "parts": [{"text": " ".join(text_parts)}]
1250
+ })
1251
+
1252
+ # Add the current prompt
1253
+ contents.append({
1254
+ "role": "user",
1255
+ "parts": [{"text": prompt}]
1256
+ })
1257
+
1258
+ chat = None
1259
+ if not self.client:
1260
+ self.client = await self.get_client()
1261
+ final_config = GenerateContentConfig(
1262
+ system_instruction=system_prompt,
1263
+ safety_settings=[
1264
+ types.SafetySetting(
1265
+ category=types.HarmCategory.HARM_CATEGORY_HARASSMENT,
1266
+ threshold=types.HarmBlockThreshold.BLOCK_NONE,
1267
+ ),
1268
+ types.SafetySetting(
1269
+ category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
1270
+ threshold=types.HarmBlockThreshold.BLOCK_NONE,
1271
+ ),
1272
+ ],
1273
+ tools=tools,
1274
+ **generation_config
1275
+ )
1276
+ if stateless:
1277
+ # For stateless mode, handle in a single call (existing behavior)
1278
+ contents = []
1279
+
1280
+ for msg in messages:
1281
+ role = "model" if msg["role"] == "assistant" else msg["role"]
1282
+ if role in ["user", "model"]:
1283
+ text_parts = [part["text"] for part in msg["content"] if "text" in part]
1284
+ if text_parts:
1285
+ contents.append({
1286
+ "role": role,
1287
+ "parts": [{"text": " ".join(text_parts)}]
1288
+ })
1289
+ retry_count = 0
+ while retry_count < max_retries:
+ try:
+ response = await self.client.aio.models.generate_content(
+ model=model,
+ contents=contents,
+ config=final_config
+ )
+ finish_reason = getattr(response.candidates[0], 'finish_reason', None)
+ if finish_reason and finish_reason.name == "MAX_TOKENS" and generation_config["max_output_tokens"] <= 1024:
+ retry_count += 1
+ self.logger.warning(
+ f"Hit MAX_TOKENS limit on stateless response. Retrying {retry_count}/{max_retries} with increased token limit."
+ )
+ final_config.max_output_tokens = 8192
+ continue
+ break
+ except Exception as e:
+ self.logger.error(
+ f"Error during generate_content: {e}"
+ )
+ if (retry_count + 1) >= max_retries:
+ raise e
+ retry_count += 1
1313
+
1314
+ # Handle function calls in stateless mode
1315
+ final_response = await self._handle_stateless_function_calls(
1316
+ response,
1317
+ model,
1318
+ contents,
1319
+ final_config,
1320
+ all_tool_calls,
1321
+ original_prompt=prompt
1322
+ )
1323
+ else:
1324
+ # MULTI-TURN CONVERSATION MODE
1325
+ chat = self.client.aio.chats.create(
1326
+ model=model,
1327
+ history=history
1328
+ )
1329
+ retry_count = 0
1330
+ # Send initial message
1331
+ while retry_count < max_retries:
1332
+ try:
1333
+ response = await chat.send_message(
1334
+ message=prompt,
1335
+ config=final_config
1336
+ )
1337
+ finish_reason = getattr(response.candidates[0], 'finish_reason', None)
1338
+ if finish_reason and finish_reason.name == "MAX_TOKENS" and generation_config["max_output_tokens"] <= 1024:
1339
+ retry_count += 1
1340
+ self.logger.warning(
1341
+ f"Hit MAX_TOKENS limit on initial response. Retrying {retry_count}/{max_retries} with increased token limit."
1342
+ )
1343
+ final_config.max_output_tokens = 8192
1344
+ continue
1345
+ break
1346
+ except Exception as e:
1347
+ # Handle specific network client error (socket/aiohttp issue)
1348
+ if "'NoneType' object has no attribute 'getaddrinfo'" in str(e):
1349
+ self.logger.warning(
1350
+ f"Encountered network client error: {e}. Resetting client and retrying."
1351
+ )
1352
+ # Reset the client
1353
+ self.client = None
1354
+ if not self.client:
1355
+ self.client = await self.get_client()
1356
+ # Recreate the chat session
1357
+ chat = self.client.aio.chats.create(
1358
+ model=model,
1359
+ history=history
1360
+ )
1361
+ retry_count += 1
1362
+ continue
1363
+
1364
+ self.logger.error(
1365
+ f"Error during initial chat.send_message: {e}"
1366
+ )
1367
+ if (retry_count + 1) >= max_retries:
1368
+ raise e
1369
+ retry_count += 1
1370
+
1371
+ has_function_calls = False
+ if response and getattr(response, "candidates", None):
+ candidate = response.candidates[0] if response.candidates else None
+ content = getattr(candidate, "content", None) if candidate else None
+ parts = getattr(content, "parts", None) if content else None
+ # Check the parts for actual function calls, not just for any content.
+ has_function_calls = bool(parts) and any(
+ getattr(part, "function_call", None) for part in parts
+ )
1377
+
1378
+ self.logger.debug(
1379
+ f"Initial response has function calls: {has_function_calls}"
1380
+ )
1381
+
1382
+ # Multi-turn function calling loop
1383
+ final_response = await self._handle_multiturn_function_calls(
1384
+ chat,
1385
+ response,
1386
+ all_tool_calls,
1387
+ original_prompt=original_prompt,
1388
+ model=model,
1389
+ max_iterations=10,
1390
+ config=final_config,
1391
+ max_retries=max_retries,
1392
+ lazy_loading=lazy_loading,
1393
+ active_tool_names=active_tool_names
1394
+ )
1395
+
1396
+ # Extract assistant response text for conversation memory
1397
+ assistant_response_text = self._safe_extract_text(final_response)
1398
+
1399
+ # If we still don't have text but have tool calls, generate a summary
1400
+ if not assistant_response_text and all_tool_calls:
1401
+ assistant_response_text = self._create_simple_summary(
1402
+ all_tool_calls
1403
+ )
1404
+
1405
+ # Handle structured output
1406
+ final_output = None
1407
+ if structured_output_for_later and use_tools and assistant_response_text:
1408
+ try:
1409
+ # Create a new generation config for structured output only
1410
+ structured_config = {
1411
+ "max_output_tokens": max_tokens or self.max_tokens,
1412
+ "temperature": temperature or self.temperature,
1413
+ "response_mime_type": "application/json"
1414
+ }
1415
+ # Set the schema based on the type of structured output
1416
+ schema_config = (
1417
+ structured_output_for_later
1418
+ if isinstance(structured_output_for_later, StructuredOutputConfig)
1419
+ else self._get_structured_config(structured_output_for_later)
1420
+ )
1421
+ if schema_config:
1422
+ self._apply_structured_output_schema(structured_config, schema_config)
1423
+ # Create a new client call without tools for structured output
1424
+ format_prompt = (
1425
+ f"Please format the following information according to the requested JSON structure. "
1426
+ f"Return only the JSON object with the requested fields:\n\n{assistant_response_text}"
1427
+ )
1428
+ structured_response = await self.client.aio.models.generate_content(
1429
+ model=model,
1430
+ contents=[{"role": "user", "parts": [{"text": format_prompt}]}],
1431
+ config=GenerateContentConfig(**structured_config)
1432
+ )
1433
+ # Extract structured text
1434
+ if structured_text := self._safe_extract_text(structured_response):
1435
+ # Parse the structured output
1436
+ if isinstance(structured_output_for_later, StructuredOutputConfig):
1437
+ final_output = await self._parse_structured_output(
1438
+ structured_text,
1439
+ structured_output_for_later
1440
+ )
1441
+ elif isinstance(structured_output_for_later, type):
1442
+ if hasattr(structured_output_for_later, 'model_validate_json'):
1443
+ final_output = structured_output_for_later.model_validate_json(structured_text)
1444
+ elif hasattr(structured_output_for_later, 'model_validate'):
1445
+ parsed_json = self._json.loads(structured_text)
1446
+ final_output = structured_output_for_later.model_validate(parsed_json)
1447
+ else:
1448
+ final_output = self._json.loads(structured_text)
1449
+ else:
1450
+ final_output = self._json.loads(structured_text)
1451
+ # # --- Fallback Logic ---
1452
+ # is_json_format = (
1453
+ # isinstance(structured_output_for_later, StructuredOutputConfig) and
1454
+ # structured_output_for_later.format == OutputFormat.JSON
1455
+ # )
1456
+ # if is_json_format and isinstance(final_output, str):
1457
+ # try:
1458
+ # self._json.loads(final_output)
1459
+ # except Exception:
1460
+ # self.logger.warning(
1461
+ # "Structured output re-formatting resulted in invalid/truncated JSON. "
1462
+ # "Falling back to original tool output."
1463
+ # )
1464
+ # final_output = assistant_response_text
1465
+ else:
1466
+ self.logger.warning(
1467
+ "No structured text received, falling back to original response"
1468
+ )
1469
+ final_output = assistant_response_text
1470
+ except Exception as e:
1471
+ self.logger.error(f"Error parsing structured output: {e}")
1472
+ # Fallback to original text if structured output fails
1473
+ final_output = assistant_response_text
1474
+ elif output_config and not use_tools:
1475
+ try:
1476
+ final_output = await self._parse_structured_output(
1477
+ assistant_response_text,
1478
+ output_config
1479
+ )
1480
+ except Exception:
1481
+ final_output = assistant_response_text
1482
+ else:
1483
+ final_output = assistant_response_text
1484
+
1485
+ # Update conversation memory with the final response
1486
+ final_assistant_message = {
1487
+ "role": "model",
1488
+ "content": [
1489
+ {
1490
+ "type": "text",
1491
+ "text": str(final_output) if final_output != assistant_response_text else assistant_response_text
1492
+ }
1493
+ ]
1494
+ }
1495
+
1496
+ # Update conversation memory with unified system
1497
+ if not stateless and conversation_history:
1498
+ tools_used = [tc.name for tc in all_tool_calls]
1499
+ await self._update_conversation_memory(
1500
+ user_id,
1501
+ session_id,
1502
+ conversation_history,
1503
+ messages + [final_assistant_message],
1504
+ system_prompt,
1505
+ turn_id,
1506
+ original_prompt,
1507
+ assistant_response_text,
1508
+ tools_used
1509
+ )
1510
+ # Create AIMessage using factory
1511
+ ai_message = AIMessageFactory.from_gemini(
1512
+ response=response,
1513
+ input_text=original_prompt,
1514
+ model=model,
1515
+ user_id=user_id,
1516
+ session_id=session_id,
1517
+ turn_id=turn_id,
1518
+ structured_output=final_output,
1519
+ tool_calls=all_tool_calls,
1520
+ conversation_history=conversation_history,
1521
+ text_response=assistant_response_text
1522
+ )
1523
+
1524
+ # Override provider to distinguish from Vertex AI
1525
+ ai_message.provider = "google_genai"
1526
+
1527
+ return ai_message
1528
+
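A usage sketch for `ask()` above. The client construction is assumed to happen elsewhere, and the `structured_output` attribute name follows the factory call in this method; the Pydantic model is illustrative:

```python
import asyncio
from pydantic import BaseModel

class CityFact(BaseModel):
    city: str
    population: int

async def demo(client) -> None:
    # `client` is an instance of this module's Google client, configured elsewhere.
    answer = await client.ask(
        "What is the largest city in Spain?",
        user_id="u1",
        session_id="s1",
    )
    print(answer)  # AIMessage with text plus tool-call metadata

    # Tools + structured output: tools run first, then the final text is
    # reformatted into the requested schema (see the two-phase logic above).
    fact = await client.ask(
        "Look up Madrid's population.",
        use_tools=True,
        structured_output=CityFact,
    )
    print(fact.structured_output)

# asyncio.run(demo(client))
```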
1529
+ def _create_simple_summary(self, all_tool_calls: List[ToolCall]) -> str:
1530
+ """Create a simple summary from tool calls."""
1531
+ if not all_tool_calls:
1532
+ return "Task completed."
1533
+
1534
+ if len(all_tool_calls) == 1:
1535
+ tc = all_tool_calls[0]
1536
+ if isinstance(tc.result, Exception):
1537
+ return f"Tool {tc.name} failed with error: {tc.result}"
1538
+ elif isinstance(tc.result, pd.DataFrame):
1539
+ if not tc.result.empty:
1540
+ return f"Tool {tc.name} returned a DataFrame with {len(tc.result)} rows."
1541
+ else:
1542
+ return f"Tool {tc.name} returned an empty DataFrame."
1543
+ elif tc.result and isinstance(tc.result, dict) and 'expression' in tc.result:
1544
+ return tc.result['expression']
1545
+ elif tc.result and isinstance(tc.result, dict) and 'result' in tc.result:
1546
+ return f"Result: {tc.result['result']}"
1547
+ else:
1548
+ # Multiple calls - show the final result
1549
+ final_tc = all_tool_calls[-1]
1550
+ if isinstance(final_tc.result, pd.DataFrame):
1551
+ if not final_tc.result.empty:
1552
+ return f"Data: {final_tc.result.to_string()}"
1553
+ else:
1554
+ return f"Final tool {final_tc.name} returned an empty DataFrame."
1555
+ if final_tc.result and isinstance(final_tc.result, dict):
1556
+ if 'result' in final_tc.result:
1557
+ return f"Final result: {final_tc.result['result']}"
1558
+ elif 'expression' in final_tc.result:
1559
+ return final_tc.result['expression']
1560
+
1561
+ return "Calculation completed."
1562
+
1563
+ def _build_function_declarations(self) -> List[types.FunctionDeclaration]:
1564
+ """Build function declarations for Google GenAI tools."""
1565
+ function_declarations = []
1566
+
1567
+ for tool in self.tool_manager.all_tools():
1568
+ tool_name = tool.name
1569
+
1570
+ if isinstance(tool, AbstractTool):
1571
+ full_schema = tool.get_tool_schema()
1572
+ tool_description = full_schema.get("description", tool.description)
1573
+ schema = full_schema.get("parameters", {}).copy()
1574
+ schema = self.clean_google_schema(schema)
1575
+ elif isinstance(tool, ToolDefinition):
1576
+ tool_description = tool.description
1577
+ schema = self.clean_google_schema(tool.input_schema.copy())
1578
+ else:
1579
+ tool_description = getattr(tool, 'description', f"Tool: {tool_name}")
1580
+ schema = getattr(tool, 'input_schema', {})
1581
+ schema = self.clean_google_schema(schema)
1582
+
1583
+ if not schema:
1584
+ schema = {"type": "object", "properties": {}, "required": []}
1585
+
1586
+ try:
1587
+ declaration = types.FunctionDeclaration(
1588
+ name=tool_name,
1589
+ description=tool_description,
1590
+ parameters=self._fix_tool_schema(schema)
1591
+ )
1592
+ function_declarations.append(declaration)
1593
+ except Exception as e:
1594
+ self.logger.error(f"Error creating {tool_name}: {e}")
1595
+ continue
1596
+
1597
+ return function_declarations
1598
+
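A minimal sketch of building one declaration the way `_build_function_declarations` does, with a toy JSON-schema dict; google-genai's Pydantic types coerce the dict into a `Schema` (and its enums are case-insensitive), though the toy tool itself is illustrative:

```python
from google.genai import types

weather_schema = {
    "type": "object",
    "properties": {
        "city": {"type": "string", "description": "City name"},
    },
    "required": ["city"],
}

declaration = types.FunctionDeclaration(
    name="get_weather",
    description="Look up current weather for a city.",
    parameters=weather_schema,
)
tool = types.Tool(function_declarations=[declaration])
print(tool)
```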
1599
+ async def ask_stream(
1600
+ self,
1601
+ prompt: str,
1602
+ model: Union[str, GoogleModel] = None,
1603
+ max_tokens: Optional[int] = None,
1604
+ temperature: Optional[float] = None,
1605
+ files: Optional[List[Union[str, Path]]] = None,
1606
+ system_prompt: Optional[str] = None,
1607
+ user_id: Optional[str] = None,
1608
+ session_id: Optional[str] = None,
1609
+ retry_config: Optional[StreamingRetryConfig] = None,
1610
+ on_max_tokens: Optional[str] = "retry", # "retry", "notify", "ignore"
1611
+ tools: Optional[List[Dict[str, Any]]] = None,
1612
+ use_tools: Optional[bool] = None,
1613
+ deep_research: bool = False,
1614
+ agent_config: Optional[Dict[str, Any]] = None,
1615
+ lazy_loading: bool = False,
1616
+ ) -> AsyncIterator[str]:
1617
+ """
1618
+ Stream Google Generative AI's response using AsyncIterator with support for Tool Calling.
1619
+
1620
+ Args:
1621
+ on_max_tokens: How to handle MAX_TOKENS finish reason:
1622
+ - "retry": Automatically retry with increased token limit
1623
+ - "notify": Yield a notification message and continue
1624
+ - "ignore": Silently continue (original behavior)
1625
+ deep_research: If True, use Google's deep research agent (stream mode)
1626
+ agent_config: Optional configuration for deep research (e.g., thinking_summaries)
1627
+ """
1628
+ model = (
1629
+ model.value if isinstance(model, GoogleModel) else model
1630
+ ) or (self.model or GoogleModel.GEMINI_2_5_FLASH.value)
1631
+
1632
+ # Handle case where model is passed as a tuple or list
1633
+ if isinstance(model, (list, tuple)):
1634
+ model = model[0]
1635
+
1636
+ # Stub for deep research streaming
1637
+ if deep_research:
1638
+ self.logger.warning(
1639
+ "Google Deep Research streaming is not yet fully implemented. "
1640
+ "Falling back to standard ask_stream() behavior."
1641
+ )
1642
+ # TODO: Implement interactions.create(stream=True) when SDK supports it
1643
+ # For now, just use regular streaming
1644
+
1645
+ turn_id = str(uuid.uuid4())
1646
+ # Default retry configuration
1647
+ if retry_config is None:
1648
+ retry_config = StreamingRetryConfig()
1649
+
1650
+ # Use the unified conversation context preparation from AbstractClient
1651
+ messages, conversation_history, system_prompt = await self._prepare_conversation_context(
1652
+ prompt, files, user_id, session_id, system_prompt
1653
+ )
1654
+
1655
+ # Prepare conversation history for Google GenAI format
1656
+ history = []
1657
+ if messages:
1658
+ for msg in messages[:-1]: # Exclude the current user message (last in list)
1659
+ role = msg['role'].lower()
1660
+ if role == 'user':
1661
+ parts = []
1662
+ for part_content in msg.get('content', []):
1663
+ if isinstance(part_content, dict) and part_content.get('type') == 'text':
1664
+ parts.append(Part(text=part_content.get('text', '')))
1665
+ if parts:
1666
+ history.append(UserContent(parts=parts))
1667
+ elif role in ['assistant', 'model']:
1668
+ parts = []
1669
+ for part_content in msg.get('content', []):
1670
+ if isinstance(part_content, dict) and part_content.get('type') == 'text':
1671
+ parts.append(Part(text=part_content.get('text', '')))
1672
+ if parts:
1673
+ history.append(ModelContent(parts=parts))
1674
+
1675
+ # --- Tool Configuration (Mirrored from ask method) ---
1676
+ _use_tools = use_tools if use_tools is not None else self.enable_tools
1677
+
1678
+ # Register requested tools if any
1679
+ if tools and isinstance(tools, list):
1680
+ for tool in tools:
1681
+ self.register_tool(tool)
1682
+
1683
+ # Determine tool strategy
1684
+ if _use_tools:
1685
+ # If explicit tools passed or just enabled, force low temp
1686
+ temperature = 0 if temperature is None else temperature
1687
+ tool_type = "custom_functions"
1688
+ elif _use_tools is None:
1689
+ # Analyze prompt
1690
+ tool_type = self._analyze_prompt_for_tools(prompt)
1691
+ else:
1692
+ tool_type = 'builtin_tools' if _use_tools else None
1693
+
1694
+ # Build the actual tool objects for Gemini
1695
+ gemini_tools = self._build_tools(tool_type) if tool_type else []
1696
+
1697
+ if _use_tools and tool_type == "custom_functions" and not gemini_tools:
1698
+ # Fallback if no tools registered
1699
+ gemini_tools = None
1700
+
1701
+ # --- Execution Loop ---
1702
+
1703
+ # Retry loop variables
1704
+ current_max_tokens = max_tokens or self.max_tokens
1705
+ retry_count = 0
1706
+
1707
+ # Variables for multi-turn tool loop
1708
+ current_message_content = prompt # Start with the user prompt
1709
+ keep_looping = True
1710
+
1711
+ # Start the chat session once
1712
+ chat = self.client.aio.chats.create(
1713
+ model=model,
1714
+ history=history,
1715
+ config=GenerateContentConfig(
1716
+ system_instruction=system_prompt,
1717
+ tools=gemini_tools,
1718
+ temperature=temperature or self.temperature,
1719
+ max_output_tokens=current_max_tokens
1720
+ )
1721
+ )
1722
+
1723
+ all_assistant_text = [] # Keep track of full text for memory update
1724
+
1725
+ while keep_looping and retry_count <= retry_config.max_retries:
1726
+ # By default, we stop after one turn unless a tool is called
1727
+ keep_looping = False
1728
+
1729
+ try:
1730
+ # If we are retrying due to max tokens, update config
1731
+ chat._config.max_output_tokens = current_max_tokens
1732
+
1733
+ assistant_content_chunk = ""
1734
+ max_tokens_reached = False
1735
+
1736
+ # We need to capture function calls from the chunks as they arrive
1737
+ collected_function_calls = []
1738
+
1739
+ async for chunk in await chat.send_message_stream(current_message_content):
1740
+ # Check for MAX_TOKENS finish reason
1741
+ if (hasattr(chunk, 'candidates') and chunk.candidates and len(chunk.candidates) > 0):
1742
+ candidate = chunk.candidates[0]
1743
+ if (hasattr(candidate, 'finish_reason') and
1744
+ str(candidate.finish_reason) == 'FinishReason.MAX_TOKENS'):
1745
+ max_tokens_reached = True
1746
+
1747
+ if on_max_tokens == "notify":
1748
+ yield f"\n\n⚠️ **Response truncated due to token limit ({current_max_tokens} tokens).**\n"
1749
+ elif on_max_tokens == "retry" and retry_config.auto_retry_on_max_tokens:
1750
+ # Break inner loop to handle retry in outer loop
1751
+ break
1752
+
1753
+ # Capture function calls from the chunk
1754
+ if (hasattr(chunk, 'candidates') and chunk.candidates):
1755
+ for candidate in chunk.candidates:
1756
+ if hasattr(candidate, 'content') and candidate.content and candidate.content.parts:
1757
+ for part in candidate.content.parts:
1758
+ if hasattr(part, 'function_call') and part.function_call:
1759
+ collected_function_calls.append(part.function_call)
1760
+
1761
+ # Yield text content if present
1762
+ if chunk.text:
1763
+ assistant_content_chunk += chunk.text
1764
+ all_assistant_text.append(chunk.text)
1765
+ yield chunk.text
1766
+
1767
+ # --- Handle Max Tokens Retry ---
1768
+ if max_tokens_reached and on_max_tokens == "retry" and retry_config.auto_retry_on_max_tokens:
1769
+ if retry_count < retry_config.max_retries:
1770
+ new_max_tokens = int(current_max_tokens * retry_config.token_increase_factor)
1771
+ yield f"\n\n🔄 **Retrying with increased limit ({new_max_tokens})...**\n\n"
1772
+ current_max_tokens = new_max_tokens
1773
+ retry_count += 1
1774
+ await self._wait_with_backoff(retry_count, retry_config)
1775
+ keep_looping = True # Force loop to continue
1776
+ continue
1777
+ else:
1778
+ yield "\n\n❌ **Maximum retries reached.**\n"
1779
+
1780
+ # --- Handle Function Calls ---
1781
+ if collected_function_calls:
1782
+ # We have tool calls to execute!
1783
+ self.logger.info(f"Streaming detected {len(collected_function_calls)} tool calls.")
1784
+
1785
+ # Execute tools (parallel)
1786
+ tool_execution_tasks = [
1787
+ self._execute_tool(fc.name, dict(fc.args))
1788
+ for fc in collected_function_calls
1789
+ ]
1790
+ tool_results = await asyncio.gather(*tool_execution_tasks, return_exceptions=True)
1791
+
1792
+ # Build the response parts containing tool outputs
1793
+ function_response_parts = []
1794
+ for fc, result in zip(collected_function_calls, tool_results):
1795
+ response_content = self._process_tool_result_for_api(result)
1796
+ function_response_parts.append(
1797
+ Part(
1798
+ function_response=types.FunctionResponse(
1799
+ name=fc.name,
1800
+ response=response_content
1801
+ )
1802
+ )
1803
+ )
1804
+
1805
+ # Set the next message to be these tool outputs
1806
+ current_message_content = function_response_parts
1807
+
1808
+ # Force the loop to run again to stream the answer based on these tools
1809
+ keep_looping = True
1810
+
1811
+ except Exception as e:
1812
+ # Handle specific network client error
1813
+ if "'NoneType' object has no attribute 'getaddrinfo'" in str(e):
1814
+ if retry_count < retry_config.max_retries:
1815
+ self.logger.warning(
1816
+ f"Encountered network client error during stream: {e}. Resetting client..."
1817
+ )
1818
+ self.client = None
1819
+ if not self.client:
1820
+ self.client = await self.get_client()
1821
+
1822
+ # Recreate chat session
1823
+ # Note: We rely on history variable being the initial history.
1824
+ # Intermediate turn state might be lost if this happens mid-conversation,
1825
+ # but this error usually happens at connection start.
1826
+ chat = self.client.aio.chats.create(
1827
+ model=model,
1828
+ history=history,
1829
+ config=GenerateContentConfig(
1830
+ system_instruction=system_prompt,
1831
+ tools=gemini_tools,
1832
+ temperature=temperature or self.temperature,
1833
+ max_output_tokens=current_max_tokens
1834
+ )
1835
+ )
1836
+ retry_count += 1
1837
+ await self._wait_with_backoff(retry_count, retry_config)
1838
+ keep_looping = True
1839
+ continue
1840
+
1841
+ if retry_count < retry_config.max_retries:
1842
+ error_msg = f"\n\n⚠️ **Streaming error (attempt {retry_count + 1}): {str(e)}. Retrying...**\n\n"
1843
+ yield error_msg
1844
+ retry_count += 1
1845
+ await self._wait_with_backoff(retry_count, retry_config)
1846
+ keep_looping = True
1847
+ continue
1848
+ else:
1849
+ yield f"\n\n❌ **Streaming failed: {str(e)}**\n"
1850
+ break
1851
+
1852
+ # Update conversation memory
1853
+ final_text = "".join(all_assistant_text)
1854
+ if final_text:
1855
+ final_assistant_message = {
1856
+ "role": "assistant", "content": [
1857
+ {"type": "text", "text": final_text}
1858
+ ]
1859
+ }
1860
+ # Extract assistant response text for conversation memory
1861
+ await self._update_conversation_memory(
1862
+ user_id,
1863
+ session_id,
1864
+ conversation_history,
1865
+ messages + [final_assistant_message],
1866
+ system_prompt,
1867
+ turn_id,
1868
+ prompt,
1869
+ final_text,
1870
+ []  # Tool usage is not yet tracked for streaming responses
1871
+ )
1872
+
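A consumption sketch for the streaming API above. Since `ask_stream` is an async generator, it is iterated directly rather than awaited; the client instance is assumed to be built elsewhere:

```python
import asyncio

async def stream_answer(client, prompt: str) -> None:
    # `client` is an instance of this module's Google client.
    async for token in client.ask_stream(
        prompt,
        user_id="u1",
        session_id="s1",
        on_max_tokens="retry",  # auto-retry with a larger limit if truncated
    ):
        print(token, end="", flush=True)

# asyncio.run(stream_answer(client, "Summarize the quarterly report."))
```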
1873
+ async def batch_ask(self, requests) -> List[AIMessage]:
1874
+ """Process multiple requests in batch."""
1875
+ # Google GenAI doesn't have a native batch API, so we process sequentially
1876
+ results = []
1877
+ for request in requests:
1878
+ result = await self.ask(**request)
1879
+ results.append(result)
1880
+ return results
1881
+
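Since `batch_ask` fans each dict out to `ask(**request)` sequentially, a request list simply mirrors `ask`'s keyword arguments; a sketch:

```python
requests = [
    {"prompt": "Translate 'hello' to French."},
    {"prompt": "What is 2 + 2?", "use_tools": True},
    {"prompt": "Name three prime numbers.", "temperature": 0.2},
]
# results = await client.batch_ask(requests)  # -> List[AIMessage], in request order
```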
1882
+ async def ask_to_image(
1883
+ self,
1884
+ prompt: str,
1885
+ image: Union[Path, bytes],
1886
+ reference_images: Optional[Union[List[Path], List[bytes]]] = None,
1887
+ model: Union[str, GoogleModel] = None,
1888
+ max_tokens: Optional[int] = None,
1889
+ temperature: Optional[float] = None,
1890
+ structured_output: Union[type, StructuredOutputConfig] = None,
1891
+ count_objects: bool = False,
1892
+ user_id: Optional[str] = None,
1893
+ session_id: Optional[str] = None,
1894
+ no_memory: bool = False,
1895
+ ) -> AIMessage:
1896
+ """
1897
+ Ask a question to Google's Generative AI using a stateful chat session.
1898
+ """
1899
+ model = model.value if isinstance(model, GoogleModel) else model
1900
+ if not model:
1901
+ model = self.model or GoogleModel.GEMINI_2_5_FLASH.value
1902
+ turn_id = str(uuid.uuid4())
1903
+ original_prompt = prompt
1904
+
1905
+ if no_memory:
1906
+ # For no_memory mode, skip conversation memory
1907
+ messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
1908
+ conversation_session = None
1909
+ else:
1910
+ messages, conversation_session, _ = await self._prepare_conversation_context(
1911
+ prompt, None, user_id, session_id, None
1912
+ )
1913
+
1914
+ # Prepare conversation history for Google GenAI format
1915
+ history = []
1916
+ if messages:
1917
+ for msg in messages[:-1]: # Exclude the current user message (last in list)
1918
+ role = msg['role'].lower()
1919
+ if role == 'user':
1920
+ parts = []
1921
+ for part_content in msg.get('content', []):
1922
+ if isinstance(part_content, dict) and part_content.get('type') == 'text':
1923
+ parts.append(Part(text=part_content.get('text', '')))
1924
+ if parts:
1925
+ history.append(UserContent(parts=parts))
1926
+ elif role in ['assistant', 'model']:
1927
+ parts = []
1928
+ for part_content in msg.get('content', []):
1929
+ if isinstance(part_content, dict) and part_content.get('type') == 'text':
1930
+ parts.append(Part(text=part_content.get('text', '')))
1931
+ if parts:
1932
+ history.append(ModelContent(parts=parts))
1933
+
1934
+ # --- Multi-Modal Content Preparation ---
1935
+ if isinstance(image, Path):
1936
+ if not image.exists():
1937
+ raise FileNotFoundError(
1938
+ f"Image file not found: {image}"
1939
+ )
1940
+ # Load the primary image
1941
+ primary_image = Image.open(image)
1942
+ elif isinstance(image, bytes):
1943
+ primary_image = Image.open(io.BytesIO(image))
1944
+ elif isinstance(image, Image.Image):
1945
+ primary_image = image
1946
+ else:
1947
+ raise ValueError(
1948
+ "Image must be a Path, bytes, or PIL.Image object."
1949
+ )
1950
+
1951
+ # The content for the API call is a list containing images and the final prompt
1952
+ contents = [primary_image]
1953
+ if reference_images:
1954
+ for ref_path in reference_images:
1955
+ self.logger.debug(
1956
+ f"Loading reference image from: {ref_path}"
1957
+ )
1958
+ if isinstance(ref_path, Path):
1959
+ if not ref_path.exists():
1960
+ raise FileNotFoundError(
1961
+ f"Reference image file not found: {ref_path}"
1962
+ )
1963
+ contents.append(Image.open(ref_path))
1964
+ elif isinstance(ref_path, bytes):
1965
+ contents.append(Image.open(io.BytesIO(ref_path)))
1966
+ elif isinstance(ref_path, Image.Image):
1967
+ # is already a PIL.Image Object
1968
+ contents.append(ref_path)
1969
+ else:
1970
+ raise ValueError(
1971
+ "Reference Image must be a Path, bytes, or PIL.Image object."
1972
+ )
1973
+
1974
+ contents.append(prompt) # The text prompt always comes last
1975
+ generation_config = {
1976
+ "max_output_tokens": max_tokens or self.max_tokens,
1977
+ "temperature": temperature or self.temperature,
1978
+ }
1979
+ output_config = self._get_structured_config(structured_output)
1980
+ structured_output_config = output_config
1981
+ # Vision models generally don't support tools, so we focus on structured output
1982
+ if structured_output_config:
1983
+ self.logger.debug("Structured output requested for vision task.")
1984
+ self._apply_structured_output_schema(generation_config, structured_output_config)
1985
+ elif count_objects:
1986
+ # Default to JSON for structured output if not specified
1987
+ structured_output_config = StructuredOutputConfig(output_type=ObjectDetectionResult)
1988
+ self._apply_structured_output_schema(generation_config, structured_output_config)
1989
+
1990
+ # Create the stateful chat session
1991
+ chat = self.client.aio.chats.create(model=model, history=history)
1992
+ final_config = GenerateContentConfig(**generation_config)
1993
+
1994
+ # Make the primary multi-modal call
1995
+ self.logger.debug(f"Sending {len(contents)} parts to the model.")
1996
+ response = await chat.send_message(
1997
+ message=contents,
1998
+ config=final_config
1999
+ )
2000
+
2001
+ # --- Response Handling ---
2002
+ final_output = None
2003
+ if structured_output_config:
2004
+ try:
2005
+ final_output = await self._parse_structured_output(
2006
+ response.text,
2007
+ structured_output_config
2008
+ )
2009
+ except Exception as e:
2010
+ self.logger.error(
2011
+ f"Failed to parse structured output from vision model: {e}"
2012
+ )
2013
+ final_output = response.text
2014
+ elif '```json' in response.text:
2015
+ # Attempt to extract JSON from markdown code block
2016
+ try:
2017
+ final_output = self._parse_json_from_text(response.text)
2018
+ except Exception as e:
2019
+ self.logger.error(
2020
+ f"Failed to parse JSON from markdown in vision model response: {e}"
2021
+ )
2022
+ final_output = response.text
2023
+ else:
2024
+ final_output = response.text
2025
+
2026
+ final_assistant_message = {
2027
+ "role": "model", "content": [
2028
+ {"type": "text", "text": final_output}
2029
+ ]
2030
+ }
2031
+ if no_memory is False:
2032
+ await self._update_conversation_memory(
2033
+ user_id,
2034
+ session_id,
2035
+ conversation_session,
2036
+ messages + [
2037
+ {
2038
+ "role": "user",
2039
+ "content": [
2040
+ {"type": "text", "text": f"[Image Analysis]: {prompt}"}
2041
+ ]
2042
+ },
2043
+ final_assistant_message
2044
+ ],
2045
+ None,
2046
+ turn_id,
2047
+ original_prompt,
2048
+ response.text,
2049
+ []
2050
+ )
2051
+ ai_message = AIMessageFactory.from_gemini(
2052
+ response=response,
2053
+ input_text=original_prompt,
2054
+ model=model,
2055
+ user_id=user_id,
2056
+ session_id=session_id,
2057
+ turn_id=turn_id,
2058
+ structured_output=final_output if final_output != response.text else None,
2059
+ tool_calls=[]
2060
+ )
2061
+ ai_message.provider = "google_genai"
2062
+ return ai_message
2063
+
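A usage sketch for `ask_to_image()` above, assuming a configured client; the file path is illustrative, and `structured_output` follows the factory call in this method:

```python
import asyncio
from pathlib import Path

async def describe(client) -> None:
    # `client` is an instance of this module's Google client.
    msg = await client.ask_to_image(
        prompt="How many people appear in this photo?",
        image=Path("photo.jpg"),   # Path, bytes, or PIL.Image
        count_objects=True,        # requests ObjectDetectionResult JSON
        no_memory=True,            # skip conversation memory for this call
    )
    print(msg.structured_output or msg)

# asyncio.run(describe(client))
```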
2064
+ async def generate_images(
2065
+ self,
2066
+ prompt_data: ImageGenerationPrompt,
2067
+ model: Union[str, GoogleModel] = GoogleModel.IMAGEN_3,
2068
+ reference_image: Optional[Path] = None,
2069
+ output_directory: Optional[Path] = None,
2070
+ mime_format: str = "image/jpeg",
2071
+ number_of_images: int = 1,
2072
+ user_id: Optional[str] = None,
2073
+ session_id: Optional[str] = None,
2074
+ add_watermark: bool = False
2075
+ ) -> AIMessage:
2076
+ """
2077
+ Generates images based on a text prompt using Imagen.
2078
+ """
2079
+ if prompt_data.model:
+ model = prompt_data.model
2081
+ model = model.value if isinstance(model, GoogleModel) else model
2082
+ self.logger.info(
2083
+ f"Starting image generation with model: {model}"
2084
+ )
2085
+ if model == GoogleModel.GEMINI_2_0_IMAGE_GENERATION.value:
2086
+ image_provider = "google_genai"
2087
+ config=types.GenerateContentConfig(
2088
+ response_modalities=['TEXT', 'IMAGE']
2089
+ )
2090
+ else:
2091
+ image_provider = "google_imagen"
2092
+
2093
+ full_prompt = prompt_data.prompt
2094
+ if prompt_data.styles:
2095
+ full_prompt += ", " + ", ".join(prompt_data.styles)
2096
+
2097
+ if reference_image:
2098
+ self.logger.info(
2099
+ f"Using reference image: {reference_image}"
2100
+ )
2101
+ if not reference_image.exists():
2102
+ raise FileNotFoundError(
2103
+ f"Reference image not found: {reference_image}"
2104
+ )
2105
+ # Load the reference image
2106
+ ref_image = Image.open(reference_image)
2107
+ full_prompt = [full_prompt, ref_image]
2108
+
2109
+ config = types.GenerateImagesConfig(
2110
+ number_of_images=number_of_images,
2111
+ output_mime_type=mime_format,
2112
+ safety_filter_level="BLOCK_LOW_AND_ABOVE",
2113
+ person_generation="ALLOW_ADULT", # Or ALLOW_ALL, etc.
2114
+ aspect_ratio=prompt_data.aspect_ratio,
2115
+ )
2116
+
2117
+ try:
2118
+ start_time = time.time()
2119
+ # Use the asynchronous client for image generation
2120
+ image_response = await self.client.aio.models.generate_images(
2121
+ model=model,
2122
+ prompt=full_prompt,
2123
+ config=config
2124
+ )
2125
+ execution_time = time.time() - start_time
2126
+
2127
+ pil_images = []
2128
+ saved_image_paths = []
2129
+ raw_response = {} # Initialize an empty dict for the raw response
2130
+
2131
+ if image_response.generated_images:
2132
+ self.logger.info(
2133
+ f"Successfully generated {len(image_response.generated_images)} image(s)."
2134
+ )
2135
+ raw_response['generated_images'] = []
2136
+ for i, generated_image in enumerate(image_response.generated_images):
2137
+ pil_image = generated_image.image
2138
+ pil_images.append(pil_image)
2139
+
2140
+ raw_response['generated_images'].append({
2141
+ 'uri': getattr(generated_image, 'uri', None),
2142
+ 'seed': getattr(generated_image, 'seed', None)
2143
+ })
2144
+
2145
+ if output_directory:
2146
+ file_path = self._save_image(pil_image, output_directory)
2147
+ saved_image_paths.append(file_path)
2148
+
2149
+ usage = CompletionUsage(execution_time=execution_time)
2150
+ # The primary 'output' is the list of raw PIL.Image objects
2151
+ # The new 'images' attribute holds the file paths
2152
+ ai_message = AIMessageFactory.from_imagen(
2153
+ output=pil_images,
2154
+ images=saved_image_paths,
2155
+ input=full_prompt,
2156
+ model=model,
2157
+ user_id=user_id,
2158
+ session_id=session_id,
2159
+ provider=image_provider,
2160
+ usage=usage,
2161
+ raw_response=raw_response
2162
+ )
2163
+ return ai_message
2164
+
2165
+ except Exception as e:
2166
+ self.logger.error(f"Image generation failed: {e}")
2167
+ raise
2168
+
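A usage sketch for `generate_images()` above. The `ImageGenerationPrompt` field names are taken from how the method reads them (`prompt`, `styles`, `aspect_ratio`, `model`); the import path and values are assumptions:

```python
# ImageGenerationPrompt / GoogleModel come from this package (import path assumed).
from pathlib import Path

prompt_data = ImageGenerationPrompt(
    prompt="A lighthouse at dawn",
    styles=["watercolor", "soft light"],  # appended to the prompt, as above
    aspect_ratio="16:9",
    model=GoogleModel.IMAGEN_3.value,
)
# msg = await client.generate_images(
#     prompt_data,
#     number_of_images=2,
#     output_directory=Path("./out"),     # saved file paths land in msg.images
# )
```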
2169
+ def _find_voice_for_speaker(self, speaker: FictionalSpeaker) -> str:
2170
+ """
2171
+ Find the best voice for a speaker based on their characteristics and gender.
2172
+
2173
+ Args:
2174
+ speaker: The fictional speaker configuration
2175
+
2176
+ Returns:
2177
+ Voice name string
2178
+ """
2179
+ if not self.voice_db:
2180
+ self.logger.warning(
2181
+ "Voice database not available, using default voice"
2182
+ )
2183
+ return "erinome" # Default fallback
2184
+
2185
+ try:
2186
+ # First, try to find voices by characteristic
2187
+ characteristic_voices = self.voice_db.get_voices_by_characteristic(
2188
+ speaker.characteristic
2189
+ )
2190
+
2191
+ if characteristic_voices:
2192
+ # Filter by gender if possible
2193
+ gender_filtered = [
2194
+ v for v in characteristic_voices if v.gender == speaker.gender
2195
+ ]
2196
+ if gender_filtered:
2197
+ return gender_filtered[0].voice_name.lower()
2198
+ else:
2199
+ # Use first voice with matching characteristic regardless of gender
2200
+ return characteristic_voices[0].voice_name.lower()
2201
+
2202
+ # Fallback: find by gender only
2203
+ gender_voices = self.voice_db.get_voices_by_gender(speaker.gender)
2204
+ if gender_voices:
2205
+ self.logger.info(
2206
+ f"Found voice by gender '{speaker.gender}': {gender_voices[0].voice_name}"
2207
+ )
2208
+ return gender_voices[0].voice_name.lower()
2209
+
2210
+ # Ultimate fallback
2211
+ self.logger.warning(
2212
+ f"No voice found for speaker {speaker.name}, using default"
2213
+ )
2214
+ return "erinome"
2215
+
2216
+ except Exception as e:
2217
+ self.logger.error(
2218
+ f"Error finding voice for speaker {speaker.name}: {e}"
2219
+ )
2220
+ return "erinome"
2221
+
2222
+ async def create_conversation_script(
2223
+ self,
2224
+ report_data: ConversationalScriptConfig,
2225
+ model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
2226
+ user_id: Optional[str] = None,
2227
+ session_id: Optional[str] = None,
2228
+ temperature: float = 0.7,
2229
+ use_structured_output: bool = False,
2230
+ max_lines: int = 20
2231
+ ) -> AIMessage:
2232
+ """
2233
+ Generates a fictional, TTS-ready conversation script for two speakers
+ from a source text report, using Google's Generative AI.
+
+ The resulting script is formatted for use with Google's TTS system.
2239
+
2240
+ Returns:
2241
+ A string formatted for Google's TTS `generate_content` method.
2242
+ Example:
2243
+ "Make Speaker1 sound tired and bored, and Speaker2 sound excited and happy:
2244
+
2245
+ Speaker1: So... what's on the agenda today?
2246
+ Speaker2: You're never going to guess!"
2247
+ """
2248
+ model = model.value if isinstance(model, GoogleModel) else model
2249
+ self.logger.info(
2250
+ f"Starting Conversation Script with model: {model}"
2251
+ )
2252
+ turn_id = str(uuid.uuid4())
2253
+
2254
+ report_text = report_data.report_text
2255
+ if not report_text:
2256
+ raise ValueError(
2257
+ "Report text is required for generating a conversation script."
2258
+ )
2259
+ # Calculate conversation length
2260
+ conversation_length = min(report_data.length // 50, max_lines)
2261
+ if conversation_length < 4:
2262
+ conversation_length = max_lines
2263
+ system_prompt = report_data.system_prompt or "Create a natural and engaging conversation script based on the provided report."
2264
+ context = report_data.context or "This conversation is based on a report about a specific topic. The characters will discuss the key findings and insights from the report."
2265
+ interviewer = None
2266
+ interviewee = None
2267
+ for speaker in report_data.speakers:
2268
+ if not speaker.name or not speaker.role or not speaker.characteristic:
2269
+ raise ValueError(
2270
+ "Each speaker must have a name, role, and characteristic."
2271
+ )
2272
+ # role (interviewer or interviewee) and characteristic (e.g., friendly, professional)
2273
+ if speaker.role == "interviewer":
2274
+ interviewer = speaker
2275
+ elif speaker.role == "interviewee":
2276
+ interviewee = speaker
2277
+
2278
+ if not interviewer or not interviewee:
2279
+ raise ValueError("Must have exactly one interviewer and one interviewee.")
2280
+ system_instruction = report_data.system_instruction or f"""
+ You are a scriptwriter. Your task is: {system_prompt} for a conversation between {interviewer.name} and {interviewee.name}.
+
+ **Source Report:**
+ ---
+ {report_text}
+ ---
+
+ **Context:**
+ {context}
2290
+
2292
+ **Characters:**
2293
+ 1. **{interviewer.name}**: The {interviewer.role}. Their personality is **{interviewer.characteristic}**.
2294
+ 2. **{interviewee.name}**: The {interviewee.role}. Their personality is **{interviewee.characteristic}**.
2295
+
2296
+ **Conversation Length:** {conversation_length} lines.
2297
+ **Instructions:**
2298
+ - The conversation must be based on the key findings, data, and conclusions of the source report.
2299
+ - The interviewer should ask insightful questions to guide the conversation.
2300
+ - The interviewee should provide answers and explanations derived from the report.
2301
+ - The dialogue should reflect the specified personalities of the characters.
2302
+ - The conversation should be engaging, natural, and suitable for a TTS system.
2303
+ - The script should be formatted for TTS, with clear speaker lines.
2304
+
2305
+ **Gender–Neutral Output (Strict)**
2306
+ - Do NOT infer anyone's gender or use third-person gendered pronouns or titles: he, him, his, she, her, hers, Mr., Mrs., Ms., sir, ma’am, etc.
2307
+ - If a third person must be referenced, use singular they/them/their or repeat the name/role (e.g., “the manager”, “Alex”).
2308
+ - Do not include gendered stage directions (“in a feminine/masculine voice”).
2309
+ - First/second person is fine inside dialogue (“I”, “you”), but NEVER use gendered third-person forms.
2310
+
2311
+ Before finalizing, scan and fix any gendered terms. If any banned term appears, rewrite that line to comply.
2312
+
2313
+ - **IMPORTANT**: Generate ONLY the dialogue script. Do not include headers, titles, or any text other than the speaker lines. The format must be exactly:
2314
+ {interviewer.name}: [dialogue]
2315
+ {interviewee.name}: [dialogue]
2316
+ """
2317
+ generation_config = {
2318
+ "max_output_tokens": self.max_tokens,
2319
+ "temperature": temperature or self.temperature,
2320
+ }
2321
+
2322
+ # Build contents for the stateless API call
2323
+ contents = [{
2324
+ "role": "user",
2325
+ "parts": [{"text": report_text}]
2326
+ }]
2327
+
2328
+ final_config = GenerateContentConfig(
2329
+ system_instruction=system_instruction,
2330
+ safety_settings=[
2331
+ types.SafetySetting(
2332
+ category=types.HarmCategory.HARM_CATEGORY_HARASSMENT,
2333
+ threshold=types.HarmBlockThreshold.BLOCK_NONE,
2334
+ ),
2335
+ types.SafetySetting(
2336
+ category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
2337
+ threshold=types.HarmBlockThreshold.BLOCK_NONE,
2338
+ ),
2339
+ ],
2340
+ tools=None,  # No tools needed for a conversation script
2341
+ **generation_config
2342
+ )
2343
+
2344
+ # Make a stateless call to the model
2345
+ if not self.client:
2346
+ self.client = await self.get_client()
2347
+ # Note: the async client call (client.aio.models.generate_content) is avoided
+ # here; the synchronous call is run in a worker thread instead.
2352
+ sync_generate_content = partial(
2353
+ self.client.models.generate_content,
2354
+ model=model,
2355
+ contents=contents,
2356
+ config=final_config
2357
+ )
2358
+ # Run the synchronous function in a separate thread
2359
+ response = await asyncio.to_thread(sync_generate_content)
2360
+ # Extract the generated script text
2361
+ script_text = response.text if hasattr(response, 'text') else str(response)
2362
+ structured_output = script_text
2363
+ if use_structured_output:
2364
+ self.logger.info("Creating structured output for TTS system...")
2365
+ try:
2366
+ # Map speakers to voices
2367
+ speaker_configs = []
2368
+ for speaker in report_data.speakers:
2369
+ voice = self._find_voice_for_speaker(speaker)
2370
+ speaker_configs.append(
2371
+ SpeakerConfig(name=speaker.name, voice=voice)
2372
+ )
2373
+ self.logger.notice(
2374
+ f"Assigned voice '{voice}' to speaker '{speaker.name}'"
2375
+ )
2376
+ structured_output = SpeechGenerationPrompt(
2377
+ prompt=script_text,
2378
+ speakers=speaker_configs
2379
+ )
2380
+ except Exception as e:
2381
+ self.logger.error(
2382
+ f"Failed to create structured output: {e}"
2383
+ )
2384
+ # Continue without structured output rather than failing
2385
+
2386
+ # Create the AIMessage response using the factory
2387
+ ai_message = AIMessageFactory.from_gemini(
2388
+ response=response,
2389
+ input_text=report_text,
2390
+ model=model,
2391
+ user_id=user_id,
2392
+ session_id=session_id,
2393
+ turn_id=turn_id,
2394
+ structured_output=structured_output,
2395
+ tool_calls=[]
2396
+ )
2397
+ ai_message.provider = "google_genai"
2398
+
2399
+ return ai_message
2400
+
2401
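+ # Illustrative sketch of the pattern used above: the synchronous google-genai call is
+ # wrapped in functools.partial and run via asyncio.to_thread so the event loop is not
+ # blocked. `client` and `config` here stand in for the objects built above.
+ #
+ #     sync_call = partial(
+ #         client.models.generate_content,
+ #         model="gemini-2.5-flash",
+ #         contents=[{"role": "user", "parts": [{"text": "Hello"}]}],
+ #         config=config,
+ #     )
+ #     response = await asyncio.to_thread(sync_call)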
+ async def generate_speech(
2402
+ self,
2403
+ prompt_data: SpeechGenerationPrompt,
2404
+ model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH_TTS,
2405
+ output_directory: Optional[Path] = None,
2406
+ system_prompt: Optional[str] = None,
2407
+ temperature: float = 0.7,
2408
+ mime_format: str = "audio/wav", # or "audio/mpeg", "audio/webm"
2409
+ user_id: Optional[str] = None,
2410
+ session_id: Optional[str] = None,
2411
+ max_retries: int = 3,
2412
+ retry_delay: float = 1.0
2413
+ ) -> AIMessage:
2414
+ """
2415
+ Generates speech from text using either a single voice or multiple voices.
2416
+ """
2417
+ start_time = time.time()
2418
+ if prompt_data.model:
2419
+ model = prompt_data.model
2420
+ model = model.value if isinstance(model, GoogleModel) else model
2421
+ self.logger.info(
2422
+ f"Starting Speech generation with model: {model}"
2423
+ )
2424
+
2425
+ # Validation of voices and fallback logic before creating the SpeechConfig:
2426
+ valid_voices = {v.value for v in TTSVoice}
2427
+ processed_speakers = []
2428
+ for speaker in prompt_data.speakers:
2429
+ final_voice = speaker.voice
2430
+ if speaker.voice not in valid_voices:
2431
+ self.logger.warning(
2432
+ f"Invalid voice '{speaker.voice}' for speaker '{speaker.name}'. "
2433
+ "Using default voice instead."
2434
+ )
2435
+ gender = speaker.gender.lower() if speaker.gender else 'female'
2436
+ final_voice = 'zephyr' if gender == 'female' else 'charon'
2437
+ processed_speakers.append(
2438
+ SpeakerConfig(name=speaker.name, voice=final_voice, gender=speaker.gender)
2439
+ )
2440
+
2441
+ speech_config = None
2442
+ if len(processed_speakers) == 1:
2443
+ # Single-speaker configuration
2444
+ speaker = processed_speakers[0]
2445
+ gender = (speaker.gender or 'female').lower()
2446
+ default_voice = 'Zephyr' if gender == 'female' else 'Charon' # Align with the zephyr/charon fallback above
2447
+ voice = speaker.voice or default_voice
2448
+ self.logger.info(f"Using single voice: {voice}")
2449
+ speech_config = types.SpeechConfig(
2450
+ voice_config=types.VoiceConfig(
2451
+ prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice)
2452
+ ),
2453
+ language_code=prompt_data.language or "en-US" # Default to US English
2454
+ )
2455
+ else:
2456
+ # Multi-speaker configuration
2457
+ self.logger.info(
2458
+ f"Using multiple voices: {[s.voice for s in processed_speakers]}"
2459
+ )
2460
+ speaker_voice_configs = [
2461
+ types.SpeakerVoiceConfig(
2462
+ speaker=s.name,
2463
+ voice_config=types.VoiceConfig(
2464
+ prebuilt_voice_config=types.PrebuiltVoiceConfig(
2465
+ voice_name=s.voice
2466
+ )
2467
+ )
2468
+ ) for s in processed_speakers
2469
+ ]
2470
+ speech_config = types.SpeechConfig(
2471
+ multi_speaker_voice_config=types.MultiSpeakerVoiceConfig(
2472
+ speaker_voice_configs=speaker_voice_configs
2473
+ ),
2474
+ language_code=prompt_data.language or "en-US" # Default to US English
2475
+ )
2476
+
2477
+ config = types.GenerateContentConfig(
2478
+ response_modalities=["AUDIO"],
2479
+ speech_config=speech_config,
2480
+ system_instruction=system_prompt,
2481
+ temperature=temperature
2482
+ )
2483
+ # Retry logic for network errors
2484
+ if not self.client:
2485
+ self.client = await self.get_client()
2486
+ # chat = self.client.aio.chats.create(model=model, history=None, config=config)
2487
+ for attempt in range(max_retries + 1):
2488
+
2489
+ try:
2490
+ if attempt > 0:
2491
+ delay = retry_delay * (2 ** (attempt - 1)) # Exponential backoff
2492
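+ # e.g. with retry_delay=1.0: waits of 1s, 2s, 4s for retry attempts 1, 2, 3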
+ self.logger.info(
2493
+ f"Retrying speech (attempt {attempt + 1}/{max_retries + 1}) after {delay}s delay..."
2494
+ )
2495
+ await asyncio.sleep(delay)
2496
+ # response = await self.client.aio.models.generate_content(
2497
+ # model=model,
2498
+ # contents=prompt_data.prompt,
2499
+ # config=config,
2500
+ # )
2501
+ sync_generate_content = partial(
2502
+ self.client.models.generate_content,
2503
+ model=model,
2504
+ contents=prompt_data.prompt,
2505
+ config=config
2506
+ )
2507
+ # Run the synchronous function in a separate thread
2508
+ response = await asyncio.to_thread(sync_generate_content)
2509
+ # Robust audio data extraction with proper validation
2510
+ audio_data = self._extract_audio_data(response)
2511
+ if audio_data is None:
2512
+ # Log the response structure for debugging
2513
+ self.logger.error(f"Failed to extract audio data from response")
2514
+ self.logger.debug(f"Response type: {type(response)}")
2515
+ if hasattr(response, 'candidates'):
2516
+ self.logger.debug(f"Candidates count: {len(response.candidates) if response.candidates else 0}")
2517
+ if response.candidates and len(response.candidates) > 0:
2518
+ candidate = response.candidates[0]
2519
+ self.logger.debug(f"Candidate type: {type(candidate)}")
2520
+ self.logger.debug(f"Candidate has content: {hasattr(candidate, 'content')}")
2521
+ if hasattr(candidate, 'content'):
2522
+ content = candidate.content
2523
+ self.logger.debug(f"Content is None: {content is None}")
2524
+ if content:
2525
+ self.logger.debug(f"Content has parts: {hasattr(content, 'parts')}")
2526
+ if hasattr(content, 'parts'):
2527
+ self.logger.debug(f"Parts count: {len(content.parts) if content.parts else 0}")
2528
+
2529
+ raise SpeechGenerationError(
2530
+ "No audio data found in response. The speech generation may have failed or "
2531
+ "the model may not support speech generation for this request."
2532
+ )
2533
+
2534
+ saved_file_paths = []
2535
+
2536
+ if output_directory:
2537
+ output_directory.mkdir(parents=True, exist_ok=True)
2538
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
2539
+ suffix = {"audio/wav": ".wav", "audio/mpeg": ".mp3", "audio/webm": ".webm"}.get(mime_format, ".wav")
+ file_path = output_directory / f"generated_speech_{timestamp}{suffix}"
2540
+
2541
+ self._save_audio_file(audio_data, file_path, mime_format)
2542
+ saved_file_paths.append(file_path)
2543
+ self.logger.info(
2544
+ f"Saved speech to {file_path}"
2545
+ )
2546
+
2547
+ execution_time = time.time() - start_time
2548
+ usage = CompletionUsage(
2549
+ execution_time=execution_time,
2550
+ # Speech API does not return token counts
2551
+ input_tokens=len(prompt_data.prompt), # Approximation
2552
+ )
2553
+
2554
+ ai_message = AIMessageFactory.from_speech(
2555
+ output=audio_data, # The raw PCM audio data
2556
+ files=saved_file_paths,
2557
+ input=prompt_data.prompt,
2558
+ model=model,
2559
+ provider="google_genai",
2560
+ usage=usage,
2561
+ user_id=user_id,
2562
+ session_id=session_id,
2563
+ raw_response=None # Response object isn't easily serializable
2564
+ )
2565
+ return ai_message
2566
+
2567
+ except (
2568
+ aiohttp.ClientPayloadError,
2569
+ aiohttp.ClientConnectionError,
2570
+ aiohttp.ClientResponseError,
2571
+ aiohttp.ServerTimeoutError,
2572
+ ConnectionResetError,
2573
+ TimeoutError,
2574
+ asyncio.TimeoutError
2575
+ ) as network_error:
2576
+ error_msg = str(network_error)
2577
+
2578
+ # Specific handling for different network errors
2579
+ if "TransferEncodingError" in error_msg:
2580
+ self.logger.warning(
2581
+ f"Transfer encoding error on attempt {attempt + 1}: {error_msg}")
2582
+ elif "Connection reset by peer" in error_msg:
2583
+ self.logger.warning(
2584
+ f"Connection reset on attempt {attempt + 1}: Server closed connection")
2585
+ elif "timeout" in error_msg.lower():
2586
+ self.logger.warning(
2587
+ f"Timeout error on attempt {attempt + 1}: {error_msg}")
2588
+ else:
2589
+ self.logger.warning(
2590
+ f"Network error on attempt {attempt + 1}: {error_msg}"
2591
+ )
2592
+
2593
+ if attempt < max_retries:
2594
+ self.logger.debug(
2595
+ f"Will retry in {retry_delay * (2 ** attempt)}s..."
2596
+ )
2597
+ continue
2598
+ else:
2599
+ # Max retries exceeded
2600
+ self.logger.error(
2601
+ f"Speech generation failed after {max_retries + 1} attempts"
2602
+ )
2603
+ raise SpeechGenerationError(
2604
+ f"Speech generation failed after {max_retries + 1} attempts. "
2605
+ f"Last error: {error_msg}. This is typically a temporary network issue - please try again."
2606
+ ) from network_error
2607
+
2608
+ except Exception as e:
2609
+ # Non-network errors - don't retry
2610
+ error_msg = str(e)
2611
+ self.logger.error(
2612
+ f"Speech generation failed with non-retryable error: {error_msg}"
2613
+ )
2614
+
2615
+ # Provide helpful error messages based on error type
2616
+ if "quota" in error_msg.lower() or "rate limit" in error_msg.lower():
2617
+ raise SpeechGenerationError(
2618
+ f"API quota or rate limit exceeded: {error_msg}. Please try again later."
2619
+ ) from e
2620
+ elif "permission" in error_msg.lower() or "unauthorized" in error_msg.lower():
2621
+ raise SpeechGenerationError(
2622
+ f"Authorization error: {error_msg}. Please check your API credentials."
2623
+ ) from e
2624
+ elif "model" in error_msg.lower():
2625
+ raise SpeechGenerationError(
2626
+ f"Model error: {error_msg}. The model '{model}' may not support speech generation."
2627
+ ) from e
2628
+ else:
2629
+ raise SpeechGenerationError(
2630
+ f"Speech generation failed: {error_msg}"
2631
+ ) from e
2632
+
2633
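+ # Illustrative usage sketch; assumes `client` is an instance of this class, and the
+ # SpeechGenerationPrompt/SpeakerConfig fields mirror how they are used above:
+ #
+ #     tts_prompt = SpeechGenerationPrompt(
+ #         prompt="Host: Welcome back.\nGuest: Glad to be here.",
+ #         speakers=[
+ #             SpeakerConfig(name="Host", voice="zephyr"),
+ #             SpeakerConfig(name="Guest", voice="charon"),
+ #         ],
+ #     )
+ #     message = await client.generate_speech(tts_prompt, output_directory=Path("./audio"))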
+ def _extract_audio_data(self, response):
2634
+ """
2635
+ Robustly extract audio data from Google GenAI response.
2636
+ Similar to the text extraction pattern used elsewhere in the codebase.
2637
+ """
2638
+ try:
2639
+ # First attempt: Direct access to expected structure
2640
+ if (hasattr(response, 'candidates') and
2641
+ response.candidates and
2642
+ len(response.candidates) > 0 and
2643
+ hasattr(response.candidates[0], 'content') and
2644
+ response.candidates[0].content and
2645
+ hasattr(response.candidates[0].content, 'parts') and
2646
+ response.candidates[0].content.parts and
2647
+ len(response.candidates[0].content.parts) > 0):
2648
+
2649
+ for part in response.candidates[0].content.parts:
2650
+ # Check for inline_data with audio data
2651
+ if (hasattr(part, 'inline_data') and
2652
+ part.inline_data and
2653
+ hasattr(part.inline_data, 'data') and
2654
+ part.inline_data.data):
2655
+ self.logger.debug("Found audio data in inline_data.data")
2656
+ return part.inline_data.data
2657
+
2658
+ # Alternative: Check for direct data attribute
2659
+ if hasattr(part, 'data') and part.data:
2660
+ self.logger.debug("Found audio data in part.data")
2661
+ return part.data
2662
+
2663
+ # Alternative: Check for binary data
2664
+ if hasattr(part, 'binary') and part.binary:
2665
+ self.logger.debug("Found audio data in part.binary")
2666
+ return part.binary
2667
+
2668
+ self.logger.warning("No audio data found in expected response structure")
2669
+ return None
2670
+
2671
+ except Exception as e:
2672
+ self.logger.error(f"Audio data extraction failed: {e}")
2673
+ return None
2674
+
2675
+ async def generate_videos(
2676
+ self,
2677
+ prompt: VideoGenerationPrompt,
2678
+ reference_image: Optional[Path] = None,
2679
+ output_directory: Optional[Path] = None,
2680
+ mime_format: str = "video/mp4",
2681
+ model: Union[str, GoogleModel] = GoogleModel.VEO_3_0,
2682
+ ) -> AIMessage:
2683
+ """
2684
+ Generate a video using the specified model and prompt.
2685
+ """
2686
+ if prompt.model:
2687
+ model = prompt.model
2688
+ model = model.value if isinstance(model, GoogleModel) else model
2689
+ if model not in [GoogleModel.VEO_2_0.value, GoogleModel.VEO_3_0.value]:
2690
+ raise ValueError(
2691
+ "Generate Videos are only supported with VEO 2.0 or VEO 3.0 models."
2692
+ )
2693
+ self.logger.info(
2694
+ f"Starting Video generation with model: {model}"
2695
+ )
2696
+ if output_directory:
2697
+ output_directory.mkdir(parents=True, exist_ok=True)
2698
+ else:
2699
+ output_directory = BASE_DIR.joinpath('static', 'generated_videos')
2700
+ args = {
2701
+ "prompt": prompt.prompt,
2702
+ "model": model,
2703
+ }
2704
+
2705
+ if reference_image:
2706
+ # Reference images are only supported by Veo 2.0:
+ self.logger.info(
+ "Veo 3.0 does not support reference images, using VEO 2.0 instead."
2709
+ )
2710
+ model = GoogleModel.VEO_2_0.value
2711
+ self.logger.info(
2712
+ f"Using reference image: {reference_image}"
2713
+ )
2714
+ if not reference_image.exists():
2715
+ raise FileNotFoundError(
2716
+ f"Reference image not found: {reference_image}"
2717
+ )
2718
+ # Load the reference image as raw bytes; types.Image expects bytes, not a PIL object
+ image_bytes = reference_image.read_bytes()
+ args['image'] = types.Image(image_bytes=image_bytes)
2721
+
2722
+ start_time = time.time()
+ # Ensure the client is initialized, as in the other generation methods
+ if not self.client:
+ self.client = await self.get_client()
+ operation = self.client.models.generate_videos(
2724
+ **args,
2725
+ config=types.GenerateVideosConfig(
2726
+ aspect_ratio=prompt.aspect_ratio or "16:9", # Default to 16:9
2727
+ negative_prompt=prompt.negative_prompt, # Optional negative prompt
2728
+ number_of_videos=prompt.number_of_videos, # Number of videos to generate
2729
+ )
2730
+ )
2731
+
2732
+ print("Video generation job started. Waiting for completion...", end="")
2733
+ spinner_chars = ['|', '/', '-', '\\']
2734
+ check_interval = 10 # Check status every 10 seconds
2735
+ spinner_index = 0
2736
+
2737
+ # This loop checks the job status every 10 seconds
2738
+ while not operation.done:
2739
+ # This inner loop runs the spinner animation for the check_interval
2740
+ for _ in range(check_interval):
2741
+ # Write the spinner character to the console
2742
+ sys.stdout.write(
2743
+ f"\rVideo generation job started. Waiting for completion... {spinner_chars[spinner_index]}"
2744
+ )
2745
+ sys.stdout.flush()
2746
+ spinner_index = (spinner_index + 1) % len(spinner_chars)
2747
+ await asyncio.sleep(1) # Animate every second without blocking the event loop
2748
+
2749
+ # After 10 seconds, get the updated operation status
2750
+ operation = self.client.operations.get(operation)
2751
+
2752
+ print("\rVideo generation job completed. ", end="")
2753
+
2754
+ video_paths = []
+ for n, generated_video in enumerate(operation.result.generated_videos):
2755
+ # Download the generated videos
2756
+ # bytes of the original MP4
2757
+ mp4_bytes = self.client.files.download(file=generated_video.video)
2758
+ video_path = self._save_video_file(
2759
+ mp4_bytes,
2760
+ output_directory,
2761
+ video_number=n,
2762
+ mime_format=mime_format
2763
+ )
+ video_paths.append(video_path)
2764
+ execution_time = time.time() - start_time
2765
+ usage = CompletionUsage(
2766
+ execution_time=execution_time,
2767
+ # Video API does not return token counts
2768
+ input_tokens=len(prompt.prompt), # Approximation
2769
+ )
2770
+
2771
+ ai_message = AIMessageFactory.from_video(
2772
+ output=operation, # The raw Video object
2773
+ files=video_paths, # Return every generated video, not just the last one
2774
+ input=prompt.prompt,
2775
+ model=model,
2776
+ provider="google_genai",
2777
+ usage=usage,
2778
+ user_id=None,
2779
+ session_id=None,
2780
+ raw_response=None # Response object isn't easily serializable
2781
+ )
2782
+ return ai_message
2783
+
2784
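+ # Illustrative usage sketch (field names mirror how VideoGenerationPrompt is consumed above):
+ #
+ #     video_prompt = VideoGenerationPrompt(
+ #         prompt="A drone shot over a coastal city at sunrise",
+ #         aspect_ratio="16:9",
+ #         negative_prompt="text overlays",
+ #         number_of_videos=1,
+ #     )
+ #     message = await client.generate_videos(video_prompt, output_directory=Path("./videos"))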
+ async def _deep_research_ask(
2785
+ self,
2786
+ prompt: str,
2787
+ background: bool = False,
2788
+ file_search_store_names: Optional[List[str]] = None,
2789
+ user_id: Optional[str] = None,
2790
+ session_id: Optional[str] = None
2791
+ ) -> AIMessage:
2792
+ """
2793
+ Perform deep research using Google's interactions.create() API.
2794
+
2795
+ Note: This is a stub implementation. Full implementation requires the
2796
+ Google Gen AI interactions SDK which uses a different API than the
2797
+ standard models.generate_content().
2798
+ """
2799
+ self.logger.warning(
2800
+ "Google Deep Research is not yet fully implemented. "
2801
+ "This feature requires the interactions API which is currently in preview. "
2802
+ "Falling back to standard ask() behavior for now."
2803
+ )
2804
+ # TODO: Implement using client.interactions.create() when SDK supports it
2805
+ # For now, fall back to regular ask without deep_research flag
2806
+ return await self.ask(
2807
+ prompt=prompt,
2808
+ user_id=user_id,
2809
+ session_id=session_id,
2810
+ deep_research=False # Prevent infinite recursion
2811
+ )
2812
+
2813
+ async def question(
2814
+ self,
2815
+ prompt: str,
2816
+ model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
2817
+ max_tokens: Optional[int] = None,
2818
+ temperature: Optional[float] = None,
2819
+ files: Optional[List[Union[str, Path]]] = None,
2820
+ user_id: Optional[str] = None,
2821
+ session_id: Optional[str] = None,
2822
+ system_prompt: Optional[str] = None,
2823
+ structured_output: Union[type, StructuredOutputConfig] = None,
2824
+ use_internal_tools: bool = False, # New parameter to control internal tools
2825
+ ) -> AIMessage:
2826
+ """
2827
+ Ask a question to Google's Generative AI in a stateless manner,
2828
+ without conversation history and with optional internal tools.
2829
+
2830
+ Args:
2831
+ prompt (str): The input prompt for the model.
2832
+ model (Union[str, GoogleModel]): The model to use, defaults to GEMINI_2_5_FLASH.
2833
+ max_tokens (int): Maximum number of tokens in the response.
2834
+ temperature (float): Sampling temperature for response generation.
2835
+ files (Optional[List[Union[str, Path]]]): Optional files to include in the request.
2836
+ system_prompt (Optional[str]): Optional system prompt to guide the model.
2837
+ structured_output (Union[type, StructuredOutputConfig]): Optional structured output configuration.
2838
+ user_id (Optional[str]): Optional user identifier for tracking.
2839
+ session_id (Optional[str]): Optional session identifier for tracking.
2840
+ use_internal_tools (bool): If True, Gemini's built-in tools (e.g., Google Search)
2841
+ will be made available to the model. Defaults to False.
2842
+ """
2843
+ self.logger.info(
2844
+ f"Initiating RAG pipeline for prompt: '{prompt[:50]}...'"
2845
+ )
2846
+
2847
+ model = model.value if isinstance(model, GoogleModel) else model
2848
+ turn_id = str(uuid.uuid4())
2849
+ original_prompt = prompt
2850
+
2851
+ output_config = self._get_structured_config(structured_output)
2852
+
2853
+ generation_config = {
2854
+ "max_output_tokens": max_tokens or self.max_tokens,
2855
+ "temperature": temperature or self.temperature,
2856
+ }
2857
+
2858
+ if output_config:
2859
+ self._apply_structured_output_schema(generation_config, output_config)
2860
+
2861
+ tools = None
2862
+ if use_internal_tools:
2863
+ tools = self._build_tools("builtin_tools") # Only built-in tools
2864
+ self.logger.debug(
2865
+ f"Enabled internal tool usage."
2866
+ )
2867
+
2868
+ # Build contents for the stateless call
2869
+ contents = []
2870
+ if files:
2871
+ for file_path in files:
2872
+ # In a real scenario, you'd handle file uploads to Gemini properly
2873
+ # This is a placeholder for file content
2874
+ contents.append(
2875
+ {
2876
+ "part": {
2877
+ "inline_data": {
2878
+ "mime_type": "application/octet-stream",
2879
+ "data": "BASE64_ENCODED_FILE_CONTENT"
2880
+ }
2881
+ }
2882
+ }
2883
+ )
2884
+
2885
+ # Add the user prompt as the first part
2886
+ contents.append({
2887
+ "role": "user",
2888
+ "parts": [{"text": prompt}]
2889
+ })
2890
+
2891
+ all_tool_calls = [] # To capture any tool calls made by internal tools
2892
+
2893
+ final_config = GenerateContentConfig(
2894
+ system_instruction=system_prompt,
2895
+ tools=tools,
2896
+ **generation_config
2897
+ )
2898
+
2899
+ response = await self.client.aio.models.generate_content(
2900
+ model=model,
2901
+ contents=contents,
2902
+ config=final_config
2903
+ )
2904
+
2905
+ # Handle potential internal tool calls if they are part of the direct generate_content response
2906
+ # Gemini can sometimes decide to use internal tools even without explicit function calling setup
2907
+ # if the tools are broadly enabled (e.g., through a general 'tool' parameter).
2908
+ # This part assumes Gemini's 'generate_content' directly returns tool calls if it uses them.
2909
+ if use_internal_tools and response.candidates and response.candidates[0].content and response.candidates[0].content.parts:
2910
+ function_calls = [
2911
+ part.function_call
2912
+ for part in response.candidates[0].content.parts
2913
+ if hasattr(part, 'function_call') and part.function_call
2914
+ ]
2915
+ if function_calls:
2916
+ tool_call_objects = []
2917
+ for fc in function_calls:
2918
+ tc = ToolCall(
2919
+ id=f"call_{uuid.uuid4().hex[:8]}",
2920
+ name=fc.name,
2921
+ arguments=dict(fc.args)
2922
+ )
2923
+ tool_call_objects.append(tc)
2924
+
2925
+ start_time = time.time()
2926
+ tool_execution_tasks = [
2927
+ self._execute_tool(fc.name, dict(fc.args)) for fc in function_calls
2928
+ ]
2929
+ tool_results = await asyncio.gather(
2930
+ *tool_execution_tasks,
2931
+ return_exceptions=True
2932
+ )
2933
+ execution_time = time.time() - start_time
2934
+
2935
+ for tc, result in zip(tool_call_objects, tool_results):
2936
+ tc.execution_time = execution_time / len(tool_call_objects)
2937
+ if isinstance(result, Exception):
2938
+ tc.error = str(result)
2939
+ else:
2940
+ tc.result = result
2941
+
2942
+ all_tool_calls.extend(tool_call_objects)
2943
+ # No multi-turn follow-up is performed here; this call is stateless.
2944
+
2945
+ final_output = None
2946
+ if output_config:
2947
+ try:
2948
+ final_output = await self._parse_structured_output(
2949
+ response.text,
2950
+ output_config
2951
+ )
2952
+ except Exception:
2953
+ final_output = response.text
2954
+
2955
+ ai_message = AIMessageFactory.from_gemini(
2956
+ response=response,
2957
+ input_text=original_prompt,
2958
+ model=model,
2959
+ user_id=user_id,
2960
+ session_id=session_id,
2961
+ turn_id=turn_id,
2962
+ structured_output=final_output if final_output != response.text else None,
2963
+ tool_calls=all_tool_calls
2964
+ )
2965
+ ai_message.provider = "google_genai"
2966
+
2967
+ return ai_message
2968
+
2969
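+ # Illustrative usage sketch (assumes `client` is an instance of this class):
+ #
+ #     answer = await client.question(
+ #         "Summarize the latest sales figures.",
+ #         use_internal_tools=True,  # exposes Gemini built-in tools (e.g., Google Search)
+ #     )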
+ async def summarize_text(
2970
+ self,
2971
+ text: str,
2972
+ max_length: int = 500,
2973
+ min_length: int = 100,
2974
+ model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
2975
+ temperature: Optional[float] = None,
2976
+ user_id: Optional[str] = None,
2977
+ session_id: Optional[str] = None,
2978
+ ) -> AIMessage:
2979
+ """
2980
+ Generates a summary for a given text in a stateless manner.
2981
+
2982
+ Args:
2983
+ text (str): The text content to summarize.
2984
+ max_length (int): The maximum desired character length for the summary.
2985
+ min_length (int): The minimum desired character length for the summary.
2986
+ model (Union[str, GoogleModel]): The model to use.
2987
+ temperature (float): Sampling temperature for response generation.
2988
+ user_id (Optional[str]): Optional user identifier for tracking.
2989
+ session_id (Optional[str]): Optional session identifier for tracking.
2990
+ """
2991
+ self.logger.info(
2992
+ f"Generating summary for text: '{text[:50]}...'"
2993
+ )
2994
+
2995
+ model = model.value if isinstance(model, GoogleModel) else model
2996
+ turn_id = str(uuid.uuid4())
2997
+
2998
+ # Define the specific system prompt for summarization
2999
+ system_prompt = f"""
3000
+ Your job is to produce a final summary from the following text and identify the main theme.
3001
+ - The summary should be concise and to the point.
3002
+ - The summary should be no longer than {max_length} characters and no less than {min_length} characters.
3003
+ - The summary should be in a single paragraph.
3004
+ """
3005
+
3006
+ generation_config = {
3007
+ "max_output_tokens": self.max_tokens,
3008
+ "temperature": temperature or self.temperature,
3009
+ }
3010
+
3011
+ # Build contents for the stateless call. The 'prompt' is the text to be summarized.
3012
+ contents = [{
3013
+ "role": "user",
3014
+ "parts": [{"text": text}]
3015
+ }]
3016
+
3017
+ final_config = GenerateContentConfig(
3018
+ system_instruction=system_prompt,
3019
+ tools=None, # No tools needed for summarization
3020
+ **generation_config
3021
+ )
3022
+
3023
+ # Make a stateless call to the model
3024
+ response = await self.client.aio.models.generate_content(
3025
+ model=model,
3026
+ contents=contents,
3027
+ config=final_config
3028
+ )
3029
+
3030
+ # Create the AIMessage response using the factory
3031
+ ai_message = AIMessageFactory.from_gemini(
3032
+ response=response,
3033
+ input_text=text,
3034
+ model=model,
3035
+ user_id=user_id,
3036
+ session_id=session_id,
3037
+ turn_id=turn_id,
3038
+ structured_output=None,
3039
+ tool_calls=[]
3040
+ )
3041
+ ai_message.provider = "google_genai"
3042
+
3043
+ return ai_message
3044
+
3045
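+ # Illustrative usage sketch: the bounds are character counts, per the docstring above.
+ #
+ #     summary = await client.summarize_text(long_text, max_length=400, min_length=150)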
+ async def translate_text(
3046
+ self,
3047
+ text: str,
3048
+ target_lang: str,
3049
+ source_lang: Optional[str] = None,
3050
+ model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
3051
+ temperature: Optional[float] = 0.2,
3052
+ user_id: Optional[str] = None,
3053
+ session_id: Optional[str] = None,
3054
+ ) -> AIMessage:
3055
+ """
3056
+ Translates a given text from a source language to a target language.
3057
+
3058
+ Args:
3059
+ text (str): The text content to translate.
3060
+ target_lang (str): The ISO code for the target language (e.g., 'es', 'fr').
3061
+ source_lang (Optional[str]): The ISO code for the source language.
3062
+ If None, the model will attempt to detect it.
3063
+ model (Union[str, GoogleModel]): The model to use. Defaults to GEMINI_2_5_FLASH,
3064
+ which is recommended for speed.
3065
+ temperature (float): Sampling temperature for response generation.
3066
+ user_id (Optional[str]): Optional user identifier for tracking.
3067
+ session_id (Optional[str]): Optional session identifier for tracking.
3068
+ """
3069
+ self.logger.info(
3070
+ f"Translating text to '{target_lang}': '{text[:50]}...'"
3071
+ )
3072
+
3073
+ model = model.value if isinstance(model, GoogleModel) else model
3074
+ turn_id = str(uuid.uuid4())
3075
+
3076
+ # Construct the system prompt for translation
3077
+ if source_lang:
3078
+ prompt_instruction = (
3079
+ f"Translate the following text from {source_lang} to {target_lang}. "
3080
+ "Only return the translated text, without any additional comments or explanations."
3081
+ )
3082
+ else:
3083
+ prompt_instruction = (
3084
+ f"First, detect the source language of the following text. Then, translate it to {target_lang}. "
3085
+ "Only return the translated text, without any additional comments or explanations."
3086
+ )
3087
+
3088
+ generation_config = {
3089
+ "max_output_tokens": self.max_tokens,
3090
+ "temperature": temperature or self.temperature,
3091
+ }
3092
+
3093
+ # Build contents for the stateless API call
3094
+ contents = [{
3095
+ "role": "user",
3096
+ "parts": [{"text": text}]
3097
+ }]
3098
+
3099
+ final_config = GenerateContentConfig(
3100
+ system_instruction=prompt_instruction,
3101
+ tools=None, # No tools needed for translation
3102
+ **generation_config
3103
+ )
3104
+
3105
+ # Make a stateless call to the model
3106
+ response = await self.client.aio.models.generate_content(
3107
+ model=model,
3108
+ contents=contents,
3109
+ config=final_config
3110
+ )
3111
+
3112
+ # Create the AIMessage response using the factory
3113
+ ai_message = AIMessageFactory.from_gemini(
3114
+ response=response,
3115
+ input_text=text,
3116
+ model=model,
3117
+ user_id=user_id,
3118
+ session_id=session_id,
3119
+ turn_id=turn_id,
3120
+ structured_output=None,
3121
+ tool_calls=[]
3122
+ )
3123
+ ai_message.provider = "google_genai"
3124
+
3125
+ return ai_message
3126
+
3127
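+ # Illustrative usage sketch: ISO language codes, as described in the docstring above.
+ #
+ #     translated = await client.translate_text("Hello, world", target_lang="es")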
+ async def extract_key_points(
3128
+ self,
3129
+ text: str,
3130
+ num_points: int = 5,
3131
+ model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
3132
+ temperature: Optional[float] = 0.3,
3133
+ user_id: Optional[str] = None,
3134
+ session_id: Optional[str] = None,
3135
+ ) -> AIMessage:
3136
+ """
3137
+ Extract *num_points* bullet-point key ideas from *text* (stateless).
3138
+ """
3139
+ self.logger.info(
3140
+ f"Extracting {num_points} key points from text: '{text[:50]}...'"
3141
+ )
3142
+
3143
+ model = model.value if isinstance(model, GoogleModel) else model
3144
+ turn_id = str(uuid.uuid4())
3145
+
3146
+ system_instruction = (
3147
+ f"Extract the {num_points} most important key points from the following text.\n"
3148
+ "- Present each point as a clear, concise bullet point (•).\n"
3149
+ "- Focus on the main ideas and significant information.\n"
3150
+ "- Each point should be self-contained and meaningful.\n"
3151
+ "- Order points by importance (most important first)."
3152
+ )
3153
+
3154
+ # Build contents for the stateless API call
3155
+ contents = [{
3156
+ "role": "user",
3157
+ "parts": [{"text": text}]
3158
+ }]
3159
+
3160
+ generation_config = {
3161
+ "max_output_tokens": self.max_tokens,
3162
+ "temperature": temperature or self.temperature,
3163
+ }
3164
+
3165
+ final_config = GenerateContentConfig(
3166
+ system_instruction=system_instruction,
3167
+ tools=None, # No tools needed for this task
3168
+ **generation_config
3169
+ )
3170
+
3171
+ # Make a stateless call to the model
3172
+ response = await self.client.aio.models.generate_content(
3173
+ model=model,
3174
+ contents=contents,
3175
+ config=final_config
3176
+ )
3177
+
3178
+ # Create the AIMessage response using the factory
3179
+ ai_message = AIMessageFactory.from_gemini(
3180
+ response=response,
3181
+ input_text=text,
3182
+ model=model,
3183
+ user_id=user_id,
3184
+ session_id=session_id,
3185
+ turn_id=turn_id,
3186
+ structured_output=None, # No structured output explicitly requested
3187
+ tool_calls=[] # No tool calls for this method
3188
+ )
3189
+ ai_message.provider = "google_genai" # Set provider
3190
+
3191
+ return ai_message
3192
+
3193
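+ # Illustrative usage sketch:
+ #
+ #     points = await client.extract_key_points(article_text, num_points=3)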
+ async def analyze_sentiment(
3194
+ self,
3195
+ text: str,
3196
+ model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
3197
+ temperature: Optional[float] = 0.1,
3198
+ user_id: Optional[str] = None,
3199
+ session_id: Optional[str] = None,
3200
+ use_structured: bool = False,
3201
+ ) -> AIMessage:
3202
+ """
3203
+ Perform sentiment analysis on text and return a structured or unstructured response.
3204
+
3205
+ Args:
3206
+ text (str): The text to analyze.
3207
+ model (Union[GoogleModel, str]): The model to use for the analysis.
3208
+ temperature (float): Sampling temperature for response generation.
3209
+ user_id (Optional[str]): Optional user identifier for tracking.
3210
+ session_id (Optional[str]): Optional session identifier for tracking.
3211
+ use_structured (bool): If True, forces a structured JSON output matching
3212
+ the SentimentAnalysis model. Defaults to False.
3213
+ """
3214
+ self.logger.info(f"Analyzing sentiment for text: '{text[:50]}...'")
3215
+
3216
+ model_name = model.value if isinstance(model, GoogleModel) else model
3217
+ turn_id = str(uuid.uuid4())
3218
+
3219
+ system_instruction = ""
3220
+ generation_config = {
3221
+ "max_output_tokens": self.max_tokens,
3222
+ "temperature": temperature or self.temperature,
3223
+ }
3224
+ structured_output_model = None
3225
+
3226
+ if use_structured:
3227
+ # ✍️ Generate a prompt to force JSON output matching the Pydantic schema
3228
+ schema = SentimentAnalysis.model_json_schema()
3229
+ system_instruction = (
3230
+ "You are an expert in sentiment analysis. Analyze the following text and provide a structured JSON response. "
3231
+ "Your response MUST be a valid JSON object that conforms to the following JSON Schema. "
3232
+ "Do not include any other text, explanations, or markdown formatting like ```json ... ```.\n\n"
3233
+ f"JSON Schema:\n{self._json.dumps(schema, indent=2)}"
3234
+ )
3235
+ # Enable Gemini's JSON mode for reliable structured output
3236
+ generation_config["response_mime_type"] = "application/json"
3237
+ structured_output_model = SentimentAnalysis
3238
+ else:
3239
+ # The original prompt for a human-readable, unstructured response
3240
+ system_instruction = (
3241
+ "Analyze the sentiment of the following text and provide a structured response.\n"
3242
+ "Your response must include:\n"
3243
+ "1. Overall sentiment (Positive, Negative, Neutral, or Mixed)\n"
3244
+ "2. Confidence level (High, Medium, Low)\n"
3245
+ "3. Key emotional indicators found in the text\n"
3246
+ "4. Brief explanation of your analysis\n\n"
3247
+ "Format your answer clearly with numbered sections."
3248
+ )
3249
+
3250
+ contents = [{"role": "user", "parts": [{"text": text}]}]
3251
+
3252
+ final_config = GenerateContentConfig(
3253
+ system_instruction={"role": "system", "parts": [{"text": system_instruction}]},
3254
+ tools=None,
3255
+ **generation_config,
3256
+ )
3257
+
3258
+ response = await self.client.aio.models.generate_content(
3259
+ model=model_name,
3260
+ contents=contents,
3261
+ config=final_config,
3262
+ )
3263
+
3264
+ ai_message = AIMessageFactory.from_gemini(
3265
+ response=response,
3266
+ input_text=text,
3267
+ model=model_name,
3268
+ user_id=user_id,
3269
+ session_id=session_id,
3270
+ turn_id=turn_id,
3271
+ structured_output=structured_output_model,
3272
+ tool_calls=[],
3273
+ )
3274
+ ai_message.provider = "google_genai"
3275
+
3276
+ return ai_message
3277
+
3278
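+ # Illustrative sketch of consuming the structured variant: with use_structured=True the
+ # request runs in JSON mode, so the reply text should validate directly against the
+ # SentimentAnalysis schema (its exact fields live with that model's definition):
+ #
+ #     reply = await client.analyze_sentiment(text, use_structured=True)
+ #     result = SentimentAnalysis.model_validate_json(reply_json_text)  # reply_json_text: the reply's raw JSON string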
+ async def analyze_product_review(
3279
+ self,
3280
+ review_text: str,
3281
+ product_id: str,
3282
+ product_name: str,
3283
+ model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
3284
+ temperature: Optional[float] = 0.1,
3285
+ user_id: Optional[str] = None,
3286
+ session_id: Optional[str] = None,
3287
+ use_structured: bool = True,
3288
+ ) -> AIMessage:
3289
+ """
3290
+ Analyze a product review and extract structured or unstructured information.
3291
+
3292
+ Args:
3293
+ review_text (str): The product review text to analyze.
3294
+ product_id (str): Unique identifier for the product.
3295
+ product_name (str): Name of the product being reviewed.
3296
+ model (Union[GoogleModel, str]): The model to use for the analysis.
3297
+ temperature (float): Sampling temperature for response generation.
3298
+ user_id (Optional[str]): Optional user identifier for tracking.
3299
+ session_id (Optional[str]): Optional session identifier for tracking.
3300
+ use_structured (bool): If True, forces a structured JSON output matching
3301
+ the ProductReview model. Defaults to True.
3302
+ """
3303
+ self.logger.info(
3304
+ f"Analyzing product review for product_id: '{product_id}'"
3305
+ )
3306
+
3307
+ model = model.value if isinstance(model, GoogleModel) else model
3308
+ turn_id = str(uuid.uuid4())
3309
+
3310
+ system_instruction = ""
3311
+ generation_config = {
3312
+ "max_output_tokens": self.max_tokens,
3313
+ "temperature": temperature or self.temperature,
3314
+ }
3315
+ structured_output_model = None
3316
+
3317
+ if use_structured:
3318
+ # Generate a prompt to force JSON output matching the Pydantic schema
3319
+ schema = ProductReview.model_json_schema()
3320
+ system_instruction = (
3321
+ "You are a product review analysis expert. Analyze the provided product review "
3322
+ "and extract the required information. Your response MUST be a valid JSON object "
3323
+ "that conforms to the following JSON Schema. Do not include any other text, "
3324
+ "explanations, or markdown formatting like ```json ... ``` around the JSON object.\n\n"
3325
+ f"JSON Schema:\n{self._json.dumps(schema)}"
3326
+ )
3327
+ # Enable Gemini's JSON mode for reliable structured output
3328
+ generation_config["response_mime_type"] = "application/json"
3329
+ structured_output_model = ProductReview
3330
+ else:
3331
+ # Generate a prompt for a more general, text-based analysis
3332
+ system_instruction = (
3333
+ "You are a product review analysis expert. Analyze the sentiment and key aspects "
3334
+ "of the following product review.\n"
3335
+ "Your response must include:\n"
3336
+ "1. Overall sentiment (Positive, Negative, or Neutral)\n"
3337
+ "2. Estimated Rating (on a scale of 1-5)\n"
3338
+ "3. Key Positive Points mentioned\n"
3339
+ "4. Key Negative Points mentioned\n"
3340
+ "5. A brief summary of the review's main points."
3341
+ )
3342
+
3343
+ # Build the user content part of the request
3344
+ user_prompt = (
3345
+ f"Product ID: {product_id}\n"
3346
+ f"Product Name: {product_name}\n"
3347
+ f"Review Text: \"{review_text}\""
3348
+ )
3349
+ contents = [{
3350
+ "role": "user",
3351
+ "parts": [{"text": user_prompt}]
3352
+ }]
3353
+
3354
+ # Finalize the generation configuration
3355
+ final_config = GenerateContentConfig(
3356
+ system_instruction={"role": "system", "parts": [{"text": system_instruction}]},
3357
+ tools=None,
3358
+ **generation_config
3359
+ )
3360
+
3361
+ # Make a stateless call to the model
3362
+ response = await self.client.aio.models.generate_content(
3363
+ model=model,
3364
+ contents=contents,
3365
+ config=final_config
3366
+ )
3367
+
3368
+ # Create the AIMessage response using the factory
3369
+ ai_message = AIMessageFactory.from_gemini(
3370
+ response=response,
3371
+ input_text=user_prompt, # Use the full prompt as input text
3372
+ model=model,
3373
+ user_id=user_id,
3374
+ session_id=session_id,
3375
+ turn_id=turn_id,
3376
+ structured_output=structured_output_model,
3377
+ tool_calls=[]
3378
+ )
3379
+ ai_message.provider = "google_genai"
3380
+
3381
+ return ai_message
3382
+
3383
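+ # Illustrative usage sketch:
+ #
+ #     review = await client.analyze_product_review(
+ #         review_text="Battery life is great, but the hinge feels flimsy.",
+ #         product_id="SKU-123",
+ #         product_name="UltraBook 14",
+ #     )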
+ async def image_generation(
3384
+ self,
3385
+ prompt_data: Union[str, ImageGenerationPrompt],
3386
+ model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH_IMAGE_PREVIEW,
3387
+ temperature: Optional[float] = None,
3388
+ prompt_instruction: Optional[str] = None,
3389
+ reference_images: List[Optional[Path]] = None,
3390
+ output_directory: Optional[Path] = None,
3391
+ user_id: Optional[str] = None,
3392
+ session_id: Optional[str] = None,
3393
+ stateless: bool = True
3394
+ ) -> AIMessage:
3395
+ """
3396
+ Generates images based on a text prompt using Nano-Banana.
3397
+ """
3398
+ if isinstance(prompt_data, str):
3399
+ prompt_data = ImageGenerationPrompt(
3400
+ prompt=prompt_data,
3401
+ model=model,
3402
+ )
3403
+ if prompt_data.model:
3404
+ model = prompt_data.model # Honor the model requested in the prompt, as in the other methods
3405
+ model = model.value if isinstance(model, GoogleModel) else model
3406
+ turn_id = str(uuid.uuid4())
3407
+ prompt_data.model = model
3408
+
3409
+ self.logger.info(
3410
+ f"Starting image generation with model: {model}"
3411
+ )
3412
+
3413
+ messages, conversation_session, _ = await self._prepare_conversation_context(
3414
+ prompt_data.prompt, None, user_id, session_id, None
3415
+ )
3416
+
3417
+ full_prompt = prompt_data.prompt
3418
+ if prompt_data.styles:
3419
+ full_prompt += ", " + ", ".join(prompt_data.styles)
3420
+
3421
+ # Prepare conversation history for Google GenAI format
3422
+ history = []
3423
+ if messages:
3424
+ for msg in messages[:-1]: # Exclude the current user message (last in list)
3425
+ role = msg['role'].lower()
3426
+ if role == 'user':
3427
+ parts = []
3428
+ for part_content in msg.get('content', []):
3429
+ if isinstance(part_content, dict) and part_content.get('type') == 'text':
3430
+ parts.append(Part(text=part_content.get('text', '')))
3431
+ if parts:
3432
+ history.append(UserContent(parts=parts))
3433
+ elif role in ['assistant', 'model']:
3434
+ parts = []
3435
+ for part_content in msg.get('content', []):
3436
+ if isinstance(part_content, dict) and part_content.get('type') == 'text':
3437
+ parts.append(Part(text=part_content.get('text', '')))
3438
+ if parts:
3439
+ history.append(ModelContent(parts=parts))
3440
+
3441
+ ref_images = []
3442
+ if reference_images:
3443
+ self.logger.info(
3444
+ f"Using reference image: {reference_images}"
3445
+ )
3446
+ for img_path in reference_images:
3447
+ if not img_path.exists():
3448
+ raise FileNotFoundError(
3449
+ f"Reference image not found: {img_path}"
3450
+ )
3451
+ # Load the reference image
3452
+ ref_images.append(Image.open(img_path))
3453
+
3454
+ config = types.GenerateContentConfig(
3455
+ response_modalities=['Text', 'Image'],
3456
+ temperature=temperature or self.temperature,
3457
+ system_instruction=prompt_instruction
3458
+ )
3459
+
3460
+ try:
3461
+ start_time = time.time()
3462
+ content = [full_prompt, *ref_images] if ref_images else [full_prompt]
3463
+ # Use the asynchronous client for image generation
3464
+ if stateless:
3465
+ response = await self.client.aio.models.generate_content(
3466
+ model=prompt_data.model,
3467
+ contents=content,
3468
+ config=config
3469
+ )
3470
+ else:
3471
+ # Create the stateful chat session
3472
+ chat = self.client.aio.chats.create(model=model, history=history, config=config)
3473
+ response = await chat.send_message(
3474
+ message=content,
3475
+ )
3476
+ execution_time = time.time() - start_time
3477
+
3478
+ pil_images = []
3479
+ saved_image_paths = []
3480
+ raw_response = {} # Initialize an empty dict for the raw response
3481
+
3482
+ raw_response['generated_images'] = []
3483
+ for part in response.candidates[0].content.parts:
3484
+ if part.text is not None:
3485
+ raw_response['text'] = part.text
3486
+ elif part.inline_data is not None:
3487
+ image = Image.open(io.BytesIO(part.inline_data.data))
3488
+ pil_images.append(image)
3489
+ if output_directory:
3490
+ if isinstance(output_directory, str):
3491
+ output_directory = Path(output_directory).resolve()
3492
+ file_path = self._save_image(image, output_directory)
3493
+ saved_image_paths.append(file_path)
3494
+ raw_response['generated_images'].append({
3495
+ 'uri': file_path,
3496
+ 'seed': None
3497
+ })
3498
+
3499
+ usage = CompletionUsage(execution_time=execution_time)
3500
+ if not stateless:
3501
+ await self._update_conversation_memory(
3502
+ user_id,
3503
+ session_id,
3504
+ conversation_session,
3505
+ messages + [
3506
+ {
3507
+ "role": "user",
3508
+ "content": [
3509
+ {"type": "text", "text": f"[Image Analysis]: {full_prompt}"}
3510
+ ]
3511
+ },
3512
+ ],
3513
+ None,
3514
+ turn_id,
3515
+ prompt_data.prompt,
3516
+ response.text,
3517
+ []
3518
+ )
3519
+ ai_message = AIMessageFactory.from_imagen(
3520
+ output=pil_images,
3521
+ images=saved_image_paths,
3522
+ input=full_prompt,
3523
+ model=model,
3524
+ user_id=user_id,
3525
+ session_id=session_id,
3526
+ provider='nano-banana',
3527
+ usage=usage,
3528
+ raw_response=raw_response
3529
+ )
3530
+ return ai_message
3531
+
3532
+ except Exception as e:
3533
+ self.logger.error(f"Image generation failed: {e}")
3534
+ raise
3535
+
3536
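+ # Illustrative usage sketch (a plain string is wrapped into an ImageGenerationPrompt above):
+ #
+ #     images = await client.image_generation(
+ #         "A watercolor fox in a misty forest",
+ #         output_directory=Path("./images"),
+ #     )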
+ def _upload_video(self, video_path: Union[str, Path]) -> str:
3537
+ """
3538
+ Uploads a video file via the Google GenAI Files API and waits until processing completes.
3539
+ """
3540
+ if isinstance(video_path, str):
3541
+ video_path = Path(video_path).resolve()
3542
+ if not video_path.exists():
3543
+ raise FileNotFoundError(
3544
+ f"Video file not found: {video_path}"
3545
+ )
3546
+ video_file = self.client.files.upload(
3547
+ file=video_path
3548
+ )
3549
+ while video_file.state == "PROCESSING":
3550
+ time.sleep(10)
3551
+ video_file = self.client.files.get(name=video_file.name)
3552
+
3553
+ if video_file.state == "FAILED":
3554
+ raise ValueError(f"Video upload failed with state: {video_file.state}")
3555
+
3556
+ self.logger.debug(
3557
+ f"Uploaded video file: {video_file.uri}"
3558
+ )
3559
+
3560
+ return video_file
3561
+
3562
+ async def video_understanding(
3563
+ self,
3564
+ prompt: str,
3565
+ model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
3566
+ temperature: Optional[float] = None,
3567
+ prompt_instruction: Optional[str] = None,
3568
+ video: Optional[Union[str, Path]] = None,
3569
+ user_id: Optional[str] = None,
3570
+ session_id: Optional[str] = None,
3571
+ stateless: bool = True,
3572
+ offsets: Optional[tuple[str, str]] = None,
3573
+ ) -> AIMessage:
3574
+ """
3575
+ Use a video (a local file or a YouTube URL) to analyze and extract information from its content.
3576
+ """
3577
+ model = model.value if isinstance(model, GoogleModel) else model
3578
+ turn_id = str(uuid.uuid4())
3579
+
3580
+ self.logger.info(
3581
+ f"Starting video analysis with model: {model}"
3582
+ )
3583
+
3584
+ if stateless:
3585
+ # For stateless mode, skip conversation memory
3586
+ messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
3587
+ conversation_history = None
3588
+ else:
3589
+ # Use the unified conversation context preparation from AbstractClient
3590
+ messages, conversation_history, prompt_instruction = await self._prepare_conversation_context(
3591
+ prompt, None, user_id, session_id, prompt_instruction, stateless=stateless
3592
+ )
3593
+
3594
+ # Prepare conversation history for Google GenAI format
3595
+ history = []
3596
+ if messages:
3597
+ for msg in messages[:-1]: # Exclude the current user message (last in list)
3598
+ role = msg['role'].lower()
3599
+ if role == 'user':
3600
+ parts = []
3601
+ for part_content in msg.get('content', []):
3602
+ if isinstance(part_content, dict) and part_content.get('type') == 'text':
3603
+ parts.append(Part(text=part_content.get('text', '')))
3604
+ if parts:
3605
+ history.append(UserContent(parts=parts))
3606
+ elif role in ['assistant', 'model']:
3607
+ parts = []
3608
+ for part_content in msg.get('content', []):
3609
+ if isinstance(part_content, dict) and part_content.get('type') == 'text':
3610
+ parts.append(Part(text=part_content.get('text', '')))
3611
+ if parts:
3612
+ history.append(ModelContent(parts=parts))
3613
+
3614
+ config = types.GenerateContentConfig(
3615
+ response_modalities=['Text'],
3616
+ temperature=temperature or self.temperature,
3617
+ system_instruction=prompt_instruction
3618
+ )
3619
+
3620
+ if isinstance(video, str) and video.startswith("http"):
3621
+ # youtube video link:
3622
+ data = types.FileData(
3623
+ file_uri=video
3624
+ )
3625
+ video_metadata = None
3626
+ if offsets:
3627
+ video_metadata=types.VideoMetadata(
3628
+ start_offset=offsets[0],
3629
+ end_offset=offsets[1]
3630
+ )
3631
+ video_info = types.Part(
3632
+ file_data=data,
3633
+ video_metadata=video_metadata
3634
+ )
3635
+ else:
3636
+ video_info = self._upload_video(video)
3637
+
3638
+ try:
3639
+ start_time = time.time()
3640
+ content = [
3641
+ types.Part(
3642
+ text=prompt
3643
+ ),
3644
+ video_info
3645
+ ]
3646
+ # Use the asynchronous client for video analysis
3647
+ if stateless:
3648
+ response = await self.client.aio.models.generate_content(
3649
+ model=model,
3650
+ contents=content,
3651
+ config=config
3652
+ )
3653
+ else:
3654
+ # Create the stateful chat session
3655
+ chat = self.client.aio.chats.create(model=model, history=history, config=config)
3656
+ response = await chat.send_message(
3657
+ message=content,
3658
+ )
3659
+ execution_time = time.time() - start_time
3660
+
3661
+ final_response = response.text
3662
+
3663
+ usage = CompletionUsage(execution_time=execution_time)
3664
+
3665
+ if not stateless:
3666
+ await self._update_conversation_memory(
3667
+ user_id,
3668
+ session_id,
3669
+ conversation_history,
3670
+ messages + [
3671
+ {
3672
+ "role": "user",
3673
+ "content": [
3674
+ {"type": "text", "text": f"[Image Analysis]: {prompt}"}
3675
+ ]
3676
+ },
3677
+ ],
3678
+ None,
3679
+ turn_id,
3680
+ prompt,
3681
+ final_response,
3682
+ []
3683
+ )
3684
+ # Create AIMessage using factory
3685
+ ai_message = AIMessageFactory.from_gemini(
3686
+ response=response,
3687
+ input_text=prompt,
3688
+ model=model,
3689
+ user_id=user_id,
3690
+ session_id=session_id,
3691
+ turn_id=turn_id,
3692
+ structured_output=final_response,
3693
+ tool_calls=None,
3694
+ conversation_history=conversation_history,
3695
+ text_response=final_response
3696
+ )
3697
+
3698
+ # Override provider to distinguish from Vertex AI
3699
+ ai_message.provider = "google_genai"
3700
+
3701
+ return ai_message
3702
+
3703
+ except Exception as e:
3704
+ self.logger.error(f"Image generation failed: {e}")
3705
+ raise
3706
+
3707
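+ # Illustrative usage sketch; the "40s"/"80s" duration-string format for offsets is an
+ # assumption based on the types.VideoMetadata start/end offsets used above:
+ #
+ #     analysis = await client.video_understanding(
+ #         "Summarize the key moments.",
+ #         video="https://www.youtube.com/watch?v=...",
+ #         offsets=("40s", "80s"),
+ #     )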
+ def _get_image_from_input(self, image: Union[str, Path, bytes, Image.Image]) -> Image.Image:
3708
+ """Helper to consistently load an image into a PIL object."""
3709
+ if isinstance(image, (str, Path)):
3710
+ return Image.open(image).convert("RGB")
3711
+ elif isinstance(image, bytes):
3712
+ return Image.open(io.BytesIO(image)).convert("RGB")
3713
+ else:
3714
+ return image.convert("RGB")
3715
+
3716
+ def _crop_box(self, pil_img: Image.Image, box: DetectionBox) -> Image.Image:
3717
+ """Crops a detection box from a PIL image with a small padding."""
3718
+ # A small padding can provide more context to the model
3719
+ pad = 8
3720
+ x1 = max(0, box.x1 - pad)
3721
+ y1 = max(0, box.y1 - pad)
3722
+ x2 = min(pil_img.width, box.x2 + pad)
3723
+ y2 = min(pil_img.height, box.y2 + pad)
3724
+ return pil_img.crop((x1, y1, x2, y2))
3725
+
3726
+ def _shelf_and_position(self, box: DetectionBox, regions: List[ShelfRegion]) -> Tuple[str, str]:
3727
+ """
3728
+ Determines the shelf and position for a given detection box using a robust
3729
+ centroid-based assignment logic.
3730
+ """
3731
+ if not regions:
3732
+ return "unknown", "center"
3733
+
3734
+ # --- NEW LOGIC: Use the object's center point for assignment ---
3735
+ center_y = box.y1 + (box.y2 - box.y1) / 2
3736
+ best_region = None
3737
+
3738
+ # 1. Primary Method: Find which shelf region CONTAINS the center point.
3739
+ for region in regions:
3740
+ if region.bbox.y1 <= center_y < region.bbox.y2:
3741
+ best_region = region
3742
+ break # Found the correct shelf
3743
+
3744
+ # 2. Fallback Method: If no shelf contains the center (edge case), find the closest one.
3745
+ if not best_region:
3746
+ min_distance = float('inf')
3747
+ for region in regions:
3748
+ shelf_center_y = region.bbox.y1 + (region.bbox.y2 - region.bbox.y1) / 2
3749
+ distance = abs(center_y - shelf_center_y)
3750
+ if distance < min_distance:
3751
+ min_distance = distance
3752
+ best_region = region
3753
+
3754
+ shelf = best_region.level if best_region else "unknown"
3755
+
3756
+ # --- Position logic remains the same, it's correct ---
3757
+ if best_region:
3758
+ box_center_x = (box.x1 + box.x2) / 2.0
3759
+ shelf_width = best_region.bbox.x2 - best_region.bbox.x1
3760
+ third_width = shelf_width / 3.0
3761
+ left_boundary = best_region.bbox.x1 + third_width
3762
+ right_boundary = best_region.bbox.x1 + 2 * third_width
3763
+
3764
+ if box_center_x < left_boundary:
3765
+ position = "left"
3766
+ elif box_center_x > right_boundary:
3767
+ position = "right"
3768
+ else:
3769
+ position = "center"
3770
+ else:
3771
+ position = "center"
3772
+
3773
+ return shelf, position
3774
+
3775
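+ # Worked example (illustrative): for a shelf spanning x1=0..x2=300, third_width is 100,
+ # so box centers in [0,100) map to "left", [100,200] to "center", and (200,300] to "right";
+ # a detection centered at x=250 on that shelf is therefore reported as "right".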
+ async def image_identification(
3776
+ self,
3777
+ prompt: str,
3778
+ image: Union[Path, bytes, Image.Image],
3779
+ detections: List[DetectionBox],
3780
+ shelf_regions: List[ShelfRegion],
3781
+ reference_images: Optional[Dict[str, Union[Path, bytes, Image.Image]]] = None,
3782
+ model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_PRO,
3783
+ temperature: float = 0.0,
3784
+ user_id: Optional[str] = None,
3785
+ session_id: Optional[str] = None,
3786
+ ) -> List[IdentifiedProduct]:
3787
+ """
3788
+ Identify products using detected boxes, reference images, and Gemini Vision.
3789
+
3790
+ This method sends the full image, reference images, and individual crops of each
3791
+ detection to Gemini for precise identification, returning a structured list of
3792
+ IdentifiedProduct objects.
3793
+
3794
+ Args:
3795
+ image: The main image of the retail display.
3796
+ detections: A list of `DetectionBox` objects from the initial detection step.
3797
+ shelf_regions: A list of `ShelfRegion` objects defining shelf boundaries.
3798
+ reference_images: Optional mapping of product label to reference image showing the ideal product.
3799
+ model: The Gemini model to use, defaulting to Gemini 2.5 Pro for its advanced vision capabilities.
3800
+ temperature: The sampling temperature for the model's response.
3801
+
3802
+ Returns:
3803
+ A list of `IdentifiedProduct` objects with detailed identification info.
3804
+ """
3805
+ self.logger.info(f"Starting Gemini identification for {len(detections)} detections.")
3806
+ model_name = model.value if isinstance(model, GoogleModel) else model
3807
+
3808
+ # --- 1. Prepare Images and Metadata ---
3809
+ main_image_pil = self._get_image_from_input(image)
3810
+ detection_details = []
3811
+ id_to_details = {}
3812
+ for i, det in enumerate(detections, start=1):
3813
+ shelf, pos = self._shelf_and_position(det, shelf_regions)
3814
+ detection_details.append({
3815
+ "id": i,
3816
+ "detection": det,
3817
+ "shelf": shelf,
3818
+ "position": pos,
3819
+ "crop": self._crop_box(main_image_pil, det),
3820
+ })
3821
+ id_to_details[i] = {"shelf": shelf, "position": pos, "detection": det}
3822
+
3823
+ # --- 2. Construct the Multi-Modal Prompt for Gemini ---
3824
+ # The prompt is a list of parts: text instructions, reference images,
3825
+ # the main image, and finally the individual crops.
3826
+ contents = [Part(text=prompt)] # Start with the user-provided prompt
3827
+
3828
3833
+
3834
+ if reference_images:
3835
+ # Add a text part to introduce the references
3836
+ contents.append(Part(text="\n\n--- REFERENCE IMAGE GUIDE ---"))
3837
+ for label, ref_img_input in reference_images.items():
3838
+ # Add the label text, then the image
3839
+ contents.append(Part(text=f"Reference for '{label}':"))
3840
+ contents.append(self._get_image_from_input(ref_img_input))
3841
+ contents.append(Part(text="--- END REFERENCE GUIDE ---"))
3842
+
3843
+ # Add the main image for overall context
3844
+ contents.append(main_image_pil)
3845
+
3846
+ # Add each cropped detection image
3847
+ for item in detection_details:
3848
+ contents.append(item['crop'])
3849
+
3850
3852
+
3853
+ # Manually generate the JSON schema from the Pydantic model
3854
+ raw_schema = IdentificationResponse.model_json_schema()
3855
+ # Clean the schema to remove unsupported properties like 'additionalProperties'
3856
+ _schema = self.clean_google_schema(raw_schema)
3857
+
3858
+ # --- 3. Configure the API Call for Structured Output ---
3859
+ generation_config = GenerateContentConfig(
3860
+ temperature=temperature,
3861
+ max_output_tokens=8192, # Generous limit for JSON with many items
3862
+ response_mime_type="application/json",
3863
+ response_schema=_schema,
3864
+ )
3865
+
3866
+ # --- 4. Call Gemini and Process the Response ---
3867
+ try:
3868
+ response = await self.client.aio.models.generate_content(
3869
+ model=model_name,
3870
+ contents=contents,
3871
+ config=generation_config,
3872
+ )
3873
+ except Exception as e:
3874
+ # On a 503 UNAVAILABLE error ({'error': {'code': 503, 'message': 'The model is overloaded. Please try again later.', 'status': 'UNAVAILABLE'}}),
+ # retry after a short delay, switching to gemini-2.5-flash instead of the Pro model.
3876
+ await asyncio.sleep(1.5)
3877
+ response = await self.client.aio.models.generate_content(
3878
+ model='gemini-2.5-flash',
3879
+ contents=contents,
3880
+ config=generation_config,
3881
+ )
3882
+
3883
+ try:
3884
+ response_text = self._safe_extract_text(response)
3885
+ if not response_text:
3886
+ raise ValueError(
3887
+ "Received an empty response from the model."
3888
+ )
3889
+
3890
+ self.logger.debug(f"Raw identification response: {response_text}")
3891
+ # The model output should conform to the Pydantic model directly
3892
+ parsed_data = IdentificationResponse.model_validate_json(response_text)
3893
+ identified_items = parsed_data.identified_products
3894
+
3895
+ # --- 5. Link LLM results back to original detections ---
3896
+ final_products = []
3897
+ for item in identified_items:
3898
+ # Case 1: Item was pre-detected (has a positive ID)
3899
+ if item.detection_id is not None and item.detection_id > 0 and item.detection_id in id_to_details:
3900
+ details = id_to_details[item.detection_id]
3901
+ item.detection_box = details["detection"]
3902
+
3903
+ # Only use geometric fallback if LLM didn't provide shelf_location
3904
+ if not item.shelf_location:
3905
+ self.logger.warning(
3906
+ f"LLM did not provide shelf_location for ID {item.detection_id}. Using geometric fallback."
3907
+ )
3908
+ item.shelf_location = details["shelf"]
3909
+ if not item.position_on_shelf:
3910
+ item.position_on_shelf = details["position"]
3911
+ final_products.append(item)
3912
+
3913
+ # Case 2: Item was newly found by the LLM
3914
+ elif item.detection_id is None:
3915
+ if item.detection_box:
3916
+ # TRUST the LLM's assignment, only use geometric fallback if missing
3917
+ if not item.shelf_location:
3918
+ self.logger.info(f"LLM didn't provide shelf_location, calculating geometrically")
3919
+ shelf, pos = self._shelf_and_position(item.detection_box, shelf_regions)
3920
+ item.shelf_location = shelf
3921
+ item.position_on_shelf = pos
3922
+ else:
3923
+ # LLM provided shelf_location, trust it but calculate position if missing
3924
+ self.logger.info(f"Using LLM-assigned shelf_location: {item.shelf_location}")
3925
+ if not item.position_on_shelf:
3926
+ _, pos = self._shelf_and_position(item.detection_box, shelf_regions)
3927
+ item.position_on_shelf = pos
3928
+
3929
+ self.logger.info(
3930
+ f"Adding new object found by LLM: {item.product_type} on shelf '{item.shelf_location}'"
3931
+ )
3932
+ final_products.append(item)
3933
+
3934
+ # Case 3: Item was newly found by the LLM (has a negative ID from our validator)
3935
+ elif item.detection_id < 0:
3936
+ if item.detection_box:
3937
+ # TRUST the LLM's assignment, only use geometric fallback if missing
3938
+ if not item.shelf_location:
3939
+ self.logger.info(f"LLM didn't provide shelf_location, calculating geometrically")
3940
+ shelf, pos = self._shelf_and_position(item.detection_box, shelf_regions)
3941
+ item.shelf_location = shelf
3942
+ item.position_on_shelf = pos
3943
+ else:
3944
+ # LLM provided shelf_location, trust it but calculate position if missing
3945
+ self.logger.info(f"Using LLM-assigned shelf_location: {item.shelf_location}")
3946
+ if not item.position_on_shelf:
3947
+ _, pos = self._shelf_and_position(item.detection_box, shelf_regions)
3948
+ item.position_on_shelf = pos
3949
+
3950
+ self.logger.info(f"Adding new object found by LLM: {item.product_type} on shelf '{item.shelf_location}'")
3951
+ final_products.append(item)
3952
+ else:
3953
+ self.logger.warning(
3954
+ f"LLM-found item with ID '{item.detection_id}' is missing a detection_box, skipping."
3955
+ )
3956
+
3957
+ self.logger.info(
3958
+ f"Successfully identified {len(final_products)} products."
3959
+ )
3960
+ return final_products
3961
+
3962
+ except Exception as e:
3963
+ self.logger.error(
3964
+ f"Gemini image identification failed: {e}"
3965
+ )
3966
+ # Fallback to creating simple products from initial detections
3967
+ fallback_products = []
3968
+ for item in detection_details:
3969
+ shelf, pos = item["shelf"], item["position"]
3970
+ det = item["detection"]
3971
+ fallback_products.append(IdentifiedProduct(
3972
+ detection_box=det,
3973
+ detection_id=item['id'],
3974
+ product_type=det.class_name,
3975
+ product_model=None,
3976
+ confidence=det.confidence * 0.5, # Lower confidence for fallback
3977
+ visual_features=["fallback_identification"],
3978
+ reference_match="none",
3979
+ shelf_location=shelf,
3980
+ position_on_shelf=pos
3981
+ ))
3982
+ return fallback_products
3983
+
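+     # Editorial sketch (not part of the original module): based on the fields
+     # consumed above, a single entry in the model's structured JSON reply might
+     # look roughly like this (values are hypothetical):
+     #
+     #     {"identified_products": [{
+     #         "detection_id": 1,
+     #         "product_type": "bottle",
+     #         "product_model": null,
+     #         "confidence": 0.87,
+     #         "visual_features": ["red label"],
+     #         "reference_match": "none",
+     #         "shelf_location": "top shelf",
+     #         "position_on_shelf": "left"
+     #     }]}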
+     async def create_speech(
+         self,
+         content: str,
+         voice_name: Optional[str] = 'charon',
+         model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
+         output_directory: Optional[Path] = None,
+         only_script: bool = False,
+         script_file: str = "narration_script.txt",
+         podcast_file: str = "generated_podcast.wav",
+         mime_format: str = "audio/wav",
+         user_id: Optional[str] = None,
+         session_id: Optional[str] = None,
+         max_retries: int = 3,
+         retry_delay: float = 1.0,
+         language: str = "en-US"
+     ) -> AIMessage:
+         """
+         Generates a simple narrative script from text and then converts it to speech.
+         This is a simpler, two-step process for text-to-speech generation.
+
+         Args:
+             content (str): The text content to generate speech from.
+             voice_name (Optional[str]): The name of the voice to use. Defaults to 'charon'.
+             model (Union[str, GoogleModel]): The model for the text-to-text step.
+             output_directory (Optional[Path]): Directory to save the audio file.
+             only_script (bool): If True, generate and save only the script, skipping audio.
+             script_file (str): Filename for the saved narration script.
+             podcast_file (str): Filename for the saved audio output.
+             mime_format (str): The audio format, e.g., 'audio/wav'.
+             user_id (Optional[str]): Optional user identifier.
+             session_id (Optional[str]): Optional session identifier.
+             max_retries (int): Maximum network retries.
+             retry_delay (float): Delay for retries.
+             language (str): BCP-47 language code for speech synthesis, e.g. 'en-US'.
+
+         Returns:
+             An AIMessage object containing the generated audio, the text script, and metadata.
+         """
+         self.logger.info(
+             "Starting a two-step text-to-speech process."
+         )
+         # Step 1: Generate a simple, narrated script from the provided text.
+         system_prompt = """
+         You are a professional scriptwriter. Given the input text, generate a script in a clear, narrative style, suitable for a voiceover.
+
+         **Instructions:**
+         - The conversation should be engaging, natural, and suitable for a TTS system.
+         - The script should be formatted for TTS, with clear speaker lines.
+         """
+         script_prompt = f"""
+         Read the following text in a clear, narrative style, suitable for a voiceover.
+         Ensure the tone is neutral and professional. Do not add any conversational
+         elements. Just read the text.
+
+         Text:
+         ---
+         {content}
+         ---
+         """
+         script_text = ''
+         script_response = None
+         try:
+             script_response = await self.ask(
+                 prompt=script_prompt,
+                 model=model,
+                 system_prompt=system_prompt,
+                 temperature=0.0,
+                 stateless=True,
+                 use_tools=False,
+             )
+             script_text = script_response.output
+         except Exception as e:
+             self.logger.error(f"Script generation failed: {e}")
+             raise SpeechGenerationError(
+                 f"Script generation failed: {str(e)}"
+             ) from e
+
+         if not script_text:
+             raise SpeechGenerationError(
+                 "Script generation failed, could not proceed with speech generation."
+             )
+
+         self.logger.info("Generated script text successfully.")
+         saved_file_paths = []
+         if only_script:
+             # If only the script is needed, save it and return it in an AIMessage
+             if output_directory is None:
+                 raise ValueError("output_directory is required when only_script=True")
+             output_directory.mkdir(parents=True, exist_ok=True)
+             script_path = output_directory / script_file
+             try:
+                 async with aiofiles.open(script_path, "w", encoding="utf-8") as f:
+                     await f.write(script_text)
+                 self.logger.info(
+                     f"Saved narration script to {script_path}"
+                 )
+                 saved_file_paths.append(script_path)
+             except Exception as e:
+                 self.logger.error(f"Failed to save script file: {e}")
+             ai_message = AIMessageFactory.from_gemini(
+                 response=script_response,
+                 text_response=script_text,
+                 input_text=content,
+                 model=model if isinstance(model, str) else model.value,
+                 user_id=user_id,
+                 session_id=session_id,
+                 files=saved_file_paths
+             )
+             return ai_message
+
+         # Step 2: Generate speech from the generated script.
+         speech_config_data = SpeechGenerationPrompt(
+             prompt=script_text,
+             speakers=[
+                 SpeakerConfig(
+                     name="narrator",
+                     voice=voice_name,
+                 )
+             ],
+             language=language
+         )
+
+         # Use the existing core logic to generate the audio
+         model = GoogleModel.GEMINI_2_5_FLASH_TTS.value
+
+         speaker = speech_config_data.speakers[0]
+         final_voice = speaker.voice
+
+         speech_config = types.SpeechConfig(
+             voice_config=types.VoiceConfig(
+                 prebuilt_voice_config=types.PrebuiltVoiceConfig(
+                     voice_name=final_voice
+                 )
+             ),
+             language_code=speech_config_data.language or "en-US"
+         )
+
+         config = types.GenerateContentConfig(
+             response_modalities=["AUDIO"],
+             speech_config=speech_config,
+             temperature=0.7
+         )
+
+         for attempt in range(max_retries + 1):
+             try:
+                 if attempt > 0:
+                     delay = retry_delay * (2 ** (attempt - 1))
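+                     # Exponential backoff: with retry_delay=1.0 this waits 1s, 2s, 4s, ... between retries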
+                     self.logger.info(
+                         f"Retrying speech (attempt {attempt + 1}/{max_retries + 1}) after {delay}s delay..."
+                     )
+                     await asyncio.sleep(delay)
+                 start_time = time.time()
+                 response = await self.client.aio.models.generate_content(
+                     model=model,
+                     contents=speech_config_data.prompt,
+                     config=config,
+                 )
+                 execution_time = time.time() - start_time
+                 audio_data = self._extract_audio_data(response)
+                 if audio_data is None:
+                     raise SpeechGenerationError(
+                         "No audio data found in response. The speech generation may have failed."
+                     )
+
+                 saved_file_paths = []
+                 script_path = None
+                 if output_directory:
+                     output_directory.mkdir(parents=True, exist_ok=True)
+                     podcast_path = output_directory / podcast_file
+                     script_path = output_directory / script_file
+                     self._save_audio_file(audio_data, podcast_path, mime_format)
+                     saved_file_paths.append(podcast_path)
+                     try:
+                         async with aiofiles.open(script_path, "w", encoding="utf-8") as f:
+                             await f.write(script_text)
+                         self.logger.info(f"Saved narration script to {script_path}")
+                         saved_file_paths.append(script_path)
+                     except Exception as e:
+                         self.logger.error(f"Failed to save script file: {e}")
+
+                 usage = CompletionUsage(
+                     execution_time=execution_time,
+                     # Character count used as a rough token approximation
+                     input_tokens=len(script_text),
+                 )
+
+                 ai_message = AIMessageFactory.from_speech(
+                     output=audio_data,
+                     files=saved_file_paths,
+                     input=script_text,
+                     model=model,
+                     provider="google_genai",
+                     documents=[script_path] if script_path else [],
+                     usage=usage,
+                     user_id=user_id,
+                     session_id=session_id,
+                     raw_response=None
+                 )
+                 return ai_message
+
+             except (
+                 aiohttp.ClientPayloadError,
+                 aiohttp.ClientConnectionError,
+                 aiohttp.ClientResponseError,
+                 aiohttp.ServerTimeoutError,
+                 ConnectionResetError,
+                 TimeoutError,
+                 asyncio.TimeoutError
+             ) as network_error:
+                 if attempt < max_retries:
+                     self.logger.warning(
+                         f"Network error on attempt {attempt + 1}: {str(network_error)}. Retrying..."
+                     )
+                     continue
+                 else:
+                     self.logger.error(
+                         f"Speech generation failed after {max_retries + 1} attempts"
+                     )
+                     raise SpeechGenerationError(
+                         f"Speech generation failed after {max_retries + 1} attempts. "
+                         f"Last error: {str(network_error)}."
+                     ) from network_error
+
+             except Exception as e:
+                 self.logger.error(
+                     f"Speech generation failed with non-retryable error: {str(e)}"
+                 )
+                 raise SpeechGenerationError(
+                     f"Speech generation failed: {str(e)}"
+                 ) from e
+
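+     # Editorial sketch (not part of the original module): a minimal call to the
+     # method above, assuming an already-configured client instance:
+     #
+     #     msg = await client.create_speech(
+     #         content="Quarterly revenue grew twelve percent year over year.",
+     #         voice_name="charon",
+     #         output_directory=Path("./tts_output"),
+     #     )
+     #     print(msg.files)  # e.g. [.../generated_podcast.wav, .../narration_script.txt]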
+     async def video_generation(
+         self,
+         prompt_data: Union[str, VideoGenerationPrompt],
+         model: Union[str, GoogleModel] = GoogleModel.VEO_3_0,
+         reference_image: Optional[Path] = None,
+         generate_image_first: bool = False,
+         image_prompt: Optional[str] = None,
+         image_generation_model: str = "imagen-4.0-generate-001",
+         aspect_ratio: Optional[str] = None,
+         resolution: Optional[str] = None,
+         negative_prompt: Optional[str] = None,
+         output_directory: Optional[Path] = None,
+         user_id: Optional[str] = None,
+         session_id: Optional[str] = None,
+         stateless: bool = True,
+         poll_interval: int = 10
+     ) -> AIMessage:
+         """
+         Generates videos based on a text prompt using Veo models.
+
+         Args:
+             prompt_data: Text prompt or VideoGenerationPrompt object
+             model: Video generation model (VEO_2_0, VEO_3_0, or VEO_3_0_FAST)
+             reference_image: Optional path to reference image. If provided, this takes precedence.
+             generate_image_first: If True and no reference_image, generates an image with Imagen first
+             image_prompt: Optional prompt for the image-generation step; defaults to the video prompt
+             image_generation_model: Model to use for image generation (default: imagen-4.0-generate-001)
+             aspect_ratio: Video aspect ratio (e.g., "16:9", "9:16"). Overrides prompt_data setting.
+             resolution: Video resolution (e.g., "720p", "1080p"). Overrides prompt_data setting.
+             negative_prompt: What to avoid in the video. Overrides prompt_data setting.
+             output_directory: Directory to save generated videos
+             user_id: User ID for conversation tracking
+             session_id: Session ID for conversation tracking
+             stateless: If True, no conversation memory is saved
+             poll_interval: Seconds between polling checks (default: 10)
+
+         Returns:
+             AIMessage containing the generated video
+         """
+         # Parse prompt data
+         if isinstance(prompt_data, str):
+             prompt_data = VideoGenerationPrompt(
+                 prompt=prompt_data,
+                 model=model.value if isinstance(model, GoogleModel) else model,
+             )
+
+         # Validate and set model
+         if prompt_data.model:
+             model = prompt_data.model
+         model = model.value if isinstance(model, GoogleModel) else model
+
+         if model not in [GoogleModel.VEO_2_0.value, GoogleModel.VEO_3_0.value, GoogleModel.VEO_3_0_FAST.value]:
+             raise ValueError(
+                 f"Video generation is only supported with the VEO 2.0 / VEO 3.0 family of models. Got: {model}"
+             )
+
+         # Setup output directory
+         if output_directory:
+             if isinstance(output_directory, str):
+                 output_directory = Path(output_directory).resolve()
+             output_directory.mkdir(parents=True, exist_ok=True)
+         else:
+             output_directory = BASE_DIR.joinpath('static', 'generated_videos')
+             output_directory.mkdir(parents=True, exist_ok=True)
+
+         turn_id = str(uuid.uuid4())
+
+         self.logger.info(
+             f"Starting video generation with model: {model}"
+         )
+
+         # Prepare conversation context if not stateless
+         if not stateless:
+             messages, conversation_session, _ = await self._prepare_conversation_context(
+                 prompt_data.prompt, None, user_id, session_id, None
+             )
+         else:
+             messages = None
+             conversation_session = None
+
+         # Override prompt settings with explicit parameters
+         final_aspect_ratio = aspect_ratio or prompt_data.aspect_ratio or "16:9"
+         final_resolution = resolution or getattr(prompt_data, 'resolution', None) or "720p"
+         final_negative_prompt = negative_prompt or prompt_data.negative_prompt or ""
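+         # Precedence: explicit arguments win over prompt_data fields, which win over the defaults ("16:9", "720p", "")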
+
+         # Step 1: Handle image input (reference or generate)
+         generated_image = None
+         image_for_video = None
+
+         if reference_image:
+             self.logger.info(
+                 f"Using reference image: {reference_image}"
+             )
+             if not reference_image.exists():
+                 raise FileNotFoundError(f"Reference image not found: {reference_image}")
+
+             # VEO 3.0 doesn't support reference images, fall back to VEO 2.0
+             # if model == GoogleModel.VEO_3_0.value:
+             #     self.logger.warning(
+             #         "VEO 3.0 does not support reference images. Switching to VEO 2.0."
+             #     )
+             #     model = GoogleModel.VEO_3_0_FAST
+
+             # Load the reference image
+             ref_image_pil = Image.open(reference_image)
+             # Convert the PIL Image to bytes for the Google GenAI API
+             img_byte_arr = io.BytesIO()
+             ref_image_pil.save(img_byte_arr, format=ref_image_pil.format or 'JPEG')
+             img_byte_arr.seek(0)
+             image_bytes = img_byte_arr.getvalue()
+
+             image_for_video = types.Image(
+                 image_bytes=image_bytes,
+                 mime_type=f"image/{(ref_image_pil.format or 'jpeg').lower()}"
+             )
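+             # types.Image carries the raw bytes plus an explicit MIME type derived from the PIL format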
+
+         elif generate_image_first:
+             self.logger.info(
+                 f"Generating image first with {image_generation_model} before video generation"
+             )
+
+             try:
+                 # Generate an image using Imagen
+                 image_config = types.GenerateImagesConfig(
+                     number_of_images=1,
+                     output_mime_type="image/jpeg",
+                     aspect_ratio=final_aspect_ratio
+                 )
+
+                 gen_prompt = image_prompt or prompt_data.prompt
+
+                 image_response = await self.client.aio.models.generate_images(
+                     model=image_generation_model,
+                     prompt=gen_prompt,
+                     config=image_config
+                 )
+
+                 if image_response.generated_images:
+                     generated_image = image_response.generated_images[0]
+                     self.logger.info(
+                         "Successfully generated reference image for video"
+                     )
+
+                     # Convert the generated image to the format needed for video generation.
+                     pil_image = generated_image.image
+                     # It can be used directly because it is a google.genai.types.Image
+                     image_for_video = pil_image
+                     # Also save the generated image to the output directory:
+                     gen_image_path = output_directory / f"generated_image_{turn_id}.jpg"
+                     pil_image.save(gen_image_path)
+                     self.logger.info(
+                         f"Saved generated reference image to: {gen_image_path}"
+                     )
+
+                     # VEO 3.0 doesn't support reference images
+                     if model == GoogleModel.VEO_3_0.value:
+                         self.logger.warning(
+                             "VEO 3.0 does not support reference images. Switching to VEO 3.0 FAST"
+                         )
+                         model = GoogleModel.VEO_3_0_FAST.value
+                 else:
+                     raise Exception("Image generation returned no images")
+
+             except Exception as e:
+                 self.logger.error(f"Image generation failed: {e}")
+                 raise Exception(f"Failed to generate reference image: {e}") from e
+
+         # Step 2: Generate video
+         self.logger.info(f"Generating video with prompt: '{prompt_data.prompt[:100]}...'")
+
+         try:
+             start_time = time.time()
+
+             # Prepare video generation arguments
+             video_args = {
+                 "model": model,
+                 "prompt": prompt_data.prompt,
+             }
+
+             if image_for_video:
+                 video_args["image"] = image_for_video
+
+             # Create the config with all parameters
+             video_config = types.GenerateVideosConfig(
+                 aspect_ratio=final_aspect_ratio,
+                 number_of_videos=prompt_data.number_of_videos or 1,
+             )
+
+             # Add resolution if supported (check model capabilities)
+             if final_resolution:
+                 video_config.resolution = final_resolution
+
+             # Add the negative prompt if provided
+             if final_negative_prompt:
+                 video_config.negative_prompt = final_negative_prompt
+
+             video_args["config"] = video_config
+
+             # Start the async video generation operation
+             self.logger.info("Starting async video generation operation...")
+             operation = await self.client.aio.models.generate_videos(**video_args)
+
+             # Step 3: Poll the operation status asynchronously
+             self.logger.info(
+                 f"Polling video generation status every {poll_interval} seconds..."
+             )
+             spinner_chars = ['|', '/', '-', '\\']
+             spinner_index = 0
+             poll_count = 0
+
+             # This loop checks the job status every poll_interval seconds
+             while not operation.done:
+                 poll_count += 1
+                 # This inner loop runs the spinner animation for the poll_interval
+                 for _ in range(poll_interval):
+                     # Write the spinner character to the console
+                     sys.stdout.write(
+                         f"\rVideo generation job started. Waiting for completion... {spinner_chars[spinner_index]}"
+                     )
+                     sys.stdout.flush()
+                     spinner_index = (spinner_index + 1) % len(spinner_chars)
+                     await asyncio.sleep(1)  # Animate every second (async version)
+
+                 # After poll_interval seconds, get the updated operation status
+                 operation = await self.client.aio.operations.get(operation)
+
+             print("\rVideo generation job completed. ", end="")
+             sys.stdout.flush()
+
+             execution_time = time.time() - start_time
+             self.logger.info(
+                 f"Video generation completed in {execution_time:.2f}s after {poll_count} polls"
+             )
+
+             # Step 4: Download and save the videos using a bytes download
+             generated_videos = operation.response.generated_videos
+
+             if not generated_videos:
+                 raise Exception("Video generation completed but no videos were returned")
+
+             saved_video_paths = []
+             raw_response = {'generated_videos': []}
+
+             for n, generated_video in enumerate(generated_videos):
+                 # Download the video bytes (MP4)
+                 # NOTE: Use the sync client for the file download, as aio may not support it
+                 mp4_bytes = self.client.files.download(file=generated_video.video)
+
+                 # Save the video to a file using the helper method
+                 video_path = self._save_video_file(
+                     mp4_bytes,
+                     output_directory,
+                     video_number=n,
+                     mime_format='video/mp4'
+                 )
+                 saved_video_paths.append(str(video_path))
+
+                 self.logger.info(f"Saved video to: {video_path}")
+
+                 # Collect metadata
+                 raw_response['generated_videos'].append({
+                     'path': str(video_path),
+                     'duration': getattr(generated_video, 'duration', None),
+                     'uri': getattr(generated_video, 'uri', None),
+                 })
+
+             # Step 5: Update conversation memory if not stateless
+             usage = CompletionUsage(
+                 execution_time=execution_time,
+                 # The video API does not return token counts, so use an approximation
+                 input_tokens=len(prompt_data.prompt),
+             )
+
+             if not stateless and conversation_session:
+                 await self._update_conversation_memory(
+                     user_id,
+                     session_id,
+                     conversation_session,
+                     messages + [
+                         {
+                             "role": "user",
+                             "content": [
+                                 {"type": "text", "text": f"[Video Generation]: {prompt_data.prompt}"}
+                             ]
+                         },
+                     ],
+                     None,
+                     turn_id,
+                     prompt_data.prompt,
+                     f"Generated {len(saved_video_paths)} video(s)",
+                     []
+                 )
+
+             # Step 6: Create and return an AIMessage using the factory
+             ai_message = AIMessageFactory.from_video(
+                 output=operation,  # The raw operation response object
+                 files=saved_video_paths,  # List of saved video file paths
+                 input=prompt_data.prompt,
+                 model=model,
+                 provider="google_genai",
+                 usage=usage,
+                 user_id=user_id,
+                 session_id=session_id,
+                 raw_response=raw_response  # Serializable summary; the operation object itself isn't easily serializable
+             )
+
+             # Add metadata about the generation
+             ai_message.metadata = {
+                 'aspect_ratio': final_aspect_ratio,
+                 'resolution': final_resolution,
+                 'negative_prompt': final_negative_prompt,
+                 'reference_image_used': reference_image is not None or generate_image_first,
+                 'image_generation_used': generate_image_first,
+                 'poll_count': poll_count,
+                 'execution_time': execution_time
+             }
+
+             self.logger.info(
+                 f"Video generation successful: {len(saved_video_paths)} video(s) created"
+             )
+
+             return ai_message
+
+         except Exception as e:
+             self.logger.error(f"Video generation failed: {e}", exc_info=True)
+             raise
+
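+     # Editorial sketch (not part of the original module): a minimal text-to-video
+     # call against the method above, assuming a configured client instance:
+     #
+     #     msg = await client.video_generation(
+     #         prompt_data="A timelapse of clouds over a mountain lake",
+     #         aspect_ratio="16:9",
+     #         resolution="720p",
+     #         output_directory=Path("./videos"),
+     #     )
+     #     print(msg.files)  # paths of the saved .mp4 file(s)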
+     def _save_video_file(
+         self,
+         video_bytes: bytes,
+         output_directory: Path,
+         video_number: int = 0,
+         mime_format: str = "video/mp4"
+     ) -> Path:
+         """
+         Helper method to save video bytes to disk.
+
+         Args:
+             video_bytes: Raw video bytes from the API
+             output_directory: Directory to save the video
+             video_number: Index number for the video filename
+             mime_format: MIME type of the video (default: video/mp4); currently
+                 informational only, as the file extension is fixed to .mp4
+
+         Returns:
+             Path to the saved video file
+         """
+         # Generate a filename based on the timestamp and video number
+         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+         filename = f"video_{timestamp}_{video_number}.mp4"
+
+         video_path = output_directory / filename
+
+         # Write the bytes to the file
+         with open(video_path, 'wb') as f:
+             f.write(video_bytes)
+
+         self.logger.info(f"Saved {len(video_bytes)} bytes to {video_path}")
+
+         return video_path
+
+
+ GoogleClient = GoogleGenAIClient  # Alias for easier imports
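+
+ # Editorial sketch (not part of the original module): end-to-end usage of the
+ # alias above. Client construction arguments are assumptions here; the method
+ # calls mirror the signatures defined in this class.
+ #
+ #     from pathlib import Path
+ #
+ #     async def _demo() -> None:
+ #         client = GoogleClient()
+ #         speech = await client.create_speech(
+ #             content="Hello from the demo.",
+ #             output_directory=Path("./out"),
+ #         )
+ #         video = await client.video_generation(
+ #             prompt_data="Waves rolling onto a beach at sunset",
+ #             output_directory=Path("./out"),
+ #         )
+ #         print(speech.files, video.files)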