ai_parrot-0.17.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentui/.prettierrc +15 -0
- agentui/QUICKSTART.md +272 -0
- agentui/README.md +59 -0
- agentui/env.example +16 -0
- agentui/jsconfig.json +14 -0
- agentui/package-lock.json +4242 -0
- agentui/package.json +34 -0
- agentui/scripts/postinstall/apply-patches.mjs +260 -0
- agentui/src/app.css +61 -0
- agentui/src/app.d.ts +13 -0
- agentui/src/app.html +12 -0
- agentui/src/components/LoadingSpinner.svelte +64 -0
- agentui/src/components/ThemeSwitcher.svelte +159 -0
- agentui/src/components/index.js +4 -0
- agentui/src/lib/api/bots.ts +60 -0
- agentui/src/lib/api/chat.ts +22 -0
- agentui/src/lib/api/http.ts +25 -0
- agentui/src/lib/components/BotCard.svelte +33 -0
- agentui/src/lib/components/ChatBubble.svelte +63 -0
- agentui/src/lib/components/Toast.svelte +21 -0
- agentui/src/lib/config.ts +20 -0
- agentui/src/lib/stores/auth.svelte.ts +73 -0
- agentui/src/lib/stores/theme.svelte.js +64 -0
- agentui/src/lib/stores/toast.svelte.ts +31 -0
- agentui/src/lib/utils/conversation.ts +39 -0
- agentui/src/routes/+layout.svelte +20 -0
- agentui/src/routes/+page.svelte +232 -0
- agentui/src/routes/login/+page.svelte +200 -0
- agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
- agentui/src/routes/talk/[agentId]/+page.ts +7 -0
- agentui/static/README.md +1 -0
- agentui/svelte.config.js +11 -0
- agentui/tailwind.config.ts +53 -0
- agentui/tsconfig.json +3 -0
- agentui/vite.config.ts +10 -0
- ai_parrot-0.17.2.dist-info/METADATA +472 -0
- ai_parrot-0.17.2.dist-info/RECORD +535 -0
- ai_parrot-0.17.2.dist-info/WHEEL +6 -0
- ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
- ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
- ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
- crew-builder/.prettierrc +15 -0
- crew-builder/QUICKSTART.md +259 -0
- crew-builder/README.md +113 -0
- crew-builder/env.example +17 -0
- crew-builder/jsconfig.json +14 -0
- crew-builder/package-lock.json +4182 -0
- crew-builder/package.json +37 -0
- crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
- crew-builder/src/app.css +62 -0
- crew-builder/src/app.d.ts +13 -0
- crew-builder/src/app.html +12 -0
- crew-builder/src/components/LoadingSpinner.svelte +64 -0
- crew-builder/src/components/ThemeSwitcher.svelte +149 -0
- crew-builder/src/components/index.js +9 -0
- crew-builder/src/lib/api/bots.ts +60 -0
- crew-builder/src/lib/api/chat.ts +80 -0
- crew-builder/src/lib/api/client.ts +56 -0
- crew-builder/src/lib/api/crew/crew.ts +136 -0
- crew-builder/src/lib/api/index.ts +5 -0
- crew-builder/src/lib/api/o365/auth.ts +65 -0
- crew-builder/src/lib/auth/auth.ts +54 -0
- crew-builder/src/lib/components/AgentNode.svelte +43 -0
- crew-builder/src/lib/components/BotCard.svelte +33 -0
- crew-builder/src/lib/components/ChatBubble.svelte +67 -0
- crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
- crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
- crew-builder/src/lib/components/JsonViewer.svelte +24 -0
- crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
- crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
- crew-builder/src/lib/components/Toast.svelte +67 -0
- crew-builder/src/lib/components/Toolbar.svelte +157 -0
- crew-builder/src/lib/components/index.ts +10 -0
- crew-builder/src/lib/config.ts +8 -0
- crew-builder/src/lib/stores/auth.svelte.ts +228 -0
- crew-builder/src/lib/stores/crewStore.ts +369 -0
- crew-builder/src/lib/stores/theme.svelte.js +145 -0
- crew-builder/src/lib/stores/toast.svelte.ts +69 -0
- crew-builder/src/lib/utils/conversation.ts +39 -0
- crew-builder/src/lib/utils/markdown.ts +122 -0
- crew-builder/src/lib/utils/talkHistory.ts +47 -0
- crew-builder/src/routes/+layout.svelte +20 -0
- crew-builder/src/routes/+page.svelte +539 -0
- crew-builder/src/routes/agents/+page.svelte +247 -0
- crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
- crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
- crew-builder/src/routes/builder/+page.svelte +204 -0
- crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
- crew-builder/src/routes/crew/ask/+page.ts +1 -0
- crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
- crew-builder/src/routes/login/+page.svelte +197 -0
- crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
- crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
- crew-builder/static/README.md +1 -0
- crew-builder/svelte.config.js +11 -0
- crew-builder/tailwind.config.ts +53 -0
- crew-builder/tsconfig.json +3 -0
- crew-builder/vite.config.ts +10 -0
- mcp_servers/calculator_server.py +309 -0
- parrot/__init__.py +27 -0
- parrot/__pycache__/__init__.cpython-310.pyc +0 -0
- parrot/__pycache__/version.cpython-310.pyc +0 -0
- parrot/_version.py +34 -0
- parrot/a2a/__init__.py +48 -0
- parrot/a2a/client.py +658 -0
- parrot/a2a/discovery.py +89 -0
- parrot/a2a/mixin.py +257 -0
- parrot/a2a/models.py +376 -0
- parrot/a2a/server.py +770 -0
- parrot/agents/__init__.py +29 -0
- parrot/bots/__init__.py +12 -0
- parrot/bots/a2a_agent.py +19 -0
- parrot/bots/abstract.py +3139 -0
- parrot/bots/agent.py +1129 -0
- parrot/bots/basic.py +9 -0
- parrot/bots/chatbot.py +669 -0
- parrot/bots/data.py +1618 -0
- parrot/bots/database/__init__.py +5 -0
- parrot/bots/database/abstract.py +3071 -0
- parrot/bots/database/cache.py +286 -0
- parrot/bots/database/models.py +468 -0
- parrot/bots/database/prompts.py +154 -0
- parrot/bots/database/retries.py +98 -0
- parrot/bots/database/router.py +269 -0
- parrot/bots/database/sql.py +41 -0
- parrot/bots/db/__init__.py +6 -0
- parrot/bots/db/abstract.py +556 -0
- parrot/bots/db/bigquery.py +602 -0
- parrot/bots/db/cache.py +85 -0
- parrot/bots/db/documentdb.py +668 -0
- parrot/bots/db/elastic.py +1014 -0
- parrot/bots/db/influx.py +898 -0
- parrot/bots/db/mock.py +96 -0
- parrot/bots/db/multi.py +783 -0
- parrot/bots/db/prompts.py +185 -0
- parrot/bots/db/sql.py +1255 -0
- parrot/bots/db/tools.py +212 -0
- parrot/bots/document.py +680 -0
- parrot/bots/hrbot.py +15 -0
- parrot/bots/kb.py +170 -0
- parrot/bots/mcp.py +36 -0
- parrot/bots/orchestration/README.md +463 -0
- parrot/bots/orchestration/__init__.py +1 -0
- parrot/bots/orchestration/agent.py +155 -0
- parrot/bots/orchestration/crew.py +3330 -0
- parrot/bots/orchestration/fsm.py +1179 -0
- parrot/bots/orchestration/hr.py +434 -0
- parrot/bots/orchestration/storage/__init__.py +4 -0
- parrot/bots/orchestration/storage/memory.py +100 -0
- parrot/bots/orchestration/storage/mixin.py +119 -0
- parrot/bots/orchestration/verify.py +202 -0
- parrot/bots/product.py +204 -0
- parrot/bots/prompts/__init__.py +96 -0
- parrot/bots/prompts/agents.py +155 -0
- parrot/bots/prompts/data.py +216 -0
- parrot/bots/prompts/output_generation.py +8 -0
- parrot/bots/scraper/__init__.py +3 -0
- parrot/bots/scraper/models.py +122 -0
- parrot/bots/scraper/scraper.py +1173 -0
- parrot/bots/scraper/templates.py +115 -0
- parrot/bots/stores/__init__.py +5 -0
- parrot/bots/stores/local.py +172 -0
- parrot/bots/webdev.py +81 -0
- parrot/cli.py +17 -0
- parrot/clients/__init__.py +16 -0
- parrot/clients/base.py +1491 -0
- parrot/clients/claude.py +1191 -0
- parrot/clients/factory.py +129 -0
- parrot/clients/google.py +4567 -0
- parrot/clients/gpt.py +1975 -0
- parrot/clients/grok.py +432 -0
- parrot/clients/groq.py +986 -0
- parrot/clients/hf.py +582 -0
- parrot/clients/models.py +18 -0
- parrot/conf.py +395 -0
- parrot/embeddings/__init__.py +9 -0
- parrot/embeddings/base.py +157 -0
- parrot/embeddings/google.py +98 -0
- parrot/embeddings/huggingface.py +74 -0
- parrot/embeddings/openai.py +84 -0
- parrot/embeddings/processor.py +88 -0
- parrot/exceptions.c +13868 -0
- parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/exceptions.pxd +22 -0
- parrot/exceptions.pxi +15 -0
- parrot/exceptions.pyx +44 -0
- parrot/generators/__init__.py +29 -0
- parrot/generators/base.py +200 -0
- parrot/generators/html.py +293 -0
- parrot/generators/react.py +205 -0
- parrot/generators/streamlit.py +203 -0
- parrot/generators/template.py +105 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/agent.py +861 -0
- parrot/handlers/agents/__init__.py +1 -0
- parrot/handlers/agents/abstract.py +900 -0
- parrot/handlers/bots.py +338 -0
- parrot/handlers/chat.py +915 -0
- parrot/handlers/creation.sql +192 -0
- parrot/handlers/crew/ARCHITECTURE.md +362 -0
- parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
- parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
- parrot/handlers/crew/__init__.py +0 -0
- parrot/handlers/crew/handler.py +801 -0
- parrot/handlers/crew/models.py +229 -0
- parrot/handlers/crew/redis_persistence.py +523 -0
- parrot/handlers/jobs/__init__.py +10 -0
- parrot/handlers/jobs/job.py +384 -0
- parrot/handlers/jobs/mixin.py +627 -0
- parrot/handlers/jobs/models.py +115 -0
- parrot/handlers/jobs/worker.py +31 -0
- parrot/handlers/models.py +596 -0
- parrot/handlers/o365_auth.py +105 -0
- parrot/handlers/stream.py +337 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/aws.py +143 -0
- parrot/interfaces/credentials.py +113 -0
- parrot/interfaces/database.py +27 -0
- parrot/interfaces/google.py +1123 -0
- parrot/interfaces/hierarchy.py +1227 -0
- parrot/interfaces/http.py +651 -0
- parrot/interfaces/images/__init__.py +0 -0
- parrot/interfaces/images/plugins/__init__.py +24 -0
- parrot/interfaces/images/plugins/abstract.py +58 -0
- parrot/interfaces/images/plugins/analisys.py +148 -0
- parrot/interfaces/images/plugins/classify.py +150 -0
- parrot/interfaces/images/plugins/classifybase.py +182 -0
- parrot/interfaces/images/plugins/detect.py +150 -0
- parrot/interfaces/images/plugins/exif.py +1103 -0
- parrot/interfaces/images/plugins/hash.py +52 -0
- parrot/interfaces/images/plugins/vision.py +104 -0
- parrot/interfaces/images/plugins/yolo.py +66 -0
- parrot/interfaces/images/plugins/zerodetect.py +197 -0
- parrot/interfaces/o365.py +978 -0
- parrot/interfaces/onedrive.py +822 -0
- parrot/interfaces/sharepoint.py +1435 -0
- parrot/interfaces/soap.py +257 -0
- parrot/loaders/__init__.py +8 -0
- parrot/loaders/abstract.py +1131 -0
- parrot/loaders/audio.py +199 -0
- parrot/loaders/basepdf.py +53 -0
- parrot/loaders/basevideo.py +1568 -0
- parrot/loaders/csv.py +409 -0
- parrot/loaders/docx.py +116 -0
- parrot/loaders/epubloader.py +316 -0
- parrot/loaders/excel.py +199 -0
- parrot/loaders/factory.py +55 -0
- parrot/loaders/files/__init__.py +0 -0
- parrot/loaders/files/abstract.py +39 -0
- parrot/loaders/files/html.py +26 -0
- parrot/loaders/files/text.py +63 -0
- parrot/loaders/html.py +152 -0
- parrot/loaders/markdown.py +442 -0
- parrot/loaders/pdf.py +373 -0
- parrot/loaders/pdfmark.py +320 -0
- parrot/loaders/pdftables.py +506 -0
- parrot/loaders/ppt.py +476 -0
- parrot/loaders/qa.py +63 -0
- parrot/loaders/splitters/__init__.py +10 -0
- parrot/loaders/splitters/base.py +138 -0
- parrot/loaders/splitters/md.py +228 -0
- parrot/loaders/splitters/token.py +143 -0
- parrot/loaders/txt.py +26 -0
- parrot/loaders/video.py +89 -0
- parrot/loaders/videolocal.py +218 -0
- parrot/loaders/videounderstanding.py +377 -0
- parrot/loaders/vimeo.py +167 -0
- parrot/loaders/web.py +599 -0
- parrot/loaders/youtube.py +504 -0
- parrot/manager/__init__.py +5 -0
- parrot/manager/manager.py +1030 -0
- parrot/mcp/__init__.py +28 -0
- parrot/mcp/adapter.py +105 -0
- parrot/mcp/cli.py +174 -0
- parrot/mcp/client.py +119 -0
- parrot/mcp/config.py +75 -0
- parrot/mcp/integration.py +842 -0
- parrot/mcp/oauth.py +933 -0
- parrot/mcp/server.py +225 -0
- parrot/mcp/transports/__init__.py +3 -0
- parrot/mcp/transports/base.py +279 -0
- parrot/mcp/transports/grpc_session.py +163 -0
- parrot/mcp/transports/http.py +312 -0
- parrot/mcp/transports/mcp.proto +108 -0
- parrot/mcp/transports/quic.py +1082 -0
- parrot/mcp/transports/sse.py +330 -0
- parrot/mcp/transports/stdio.py +309 -0
- parrot/mcp/transports/unix.py +395 -0
- parrot/mcp/transports/websocket.py +547 -0
- parrot/memory/__init__.py +16 -0
- parrot/memory/abstract.py +209 -0
- parrot/memory/agent.py +32 -0
- parrot/memory/cache.py +175 -0
- parrot/memory/core.py +555 -0
- parrot/memory/file.py +153 -0
- parrot/memory/mem.py +131 -0
- parrot/memory/redis.py +613 -0
- parrot/models/__init__.py +46 -0
- parrot/models/basic.py +118 -0
- parrot/models/compliance.py +208 -0
- parrot/models/crew.py +395 -0
- parrot/models/detections.py +654 -0
- parrot/models/generation.py +85 -0
- parrot/models/google.py +223 -0
- parrot/models/groq.py +23 -0
- parrot/models/openai.py +30 -0
- parrot/models/outputs.py +285 -0
- parrot/models/responses.py +938 -0
- parrot/notifications/__init__.py +743 -0
- parrot/openapi/__init__.py +3 -0
- parrot/openapi/components.yaml +641 -0
- parrot/openapi/config.py +322 -0
- parrot/outputs/__init__.py +32 -0
- parrot/outputs/formats/__init__.py +108 -0
- parrot/outputs/formats/altair.py +359 -0
- parrot/outputs/formats/application.py +122 -0
- parrot/outputs/formats/base.py +351 -0
- parrot/outputs/formats/bokeh.py +356 -0
- parrot/outputs/formats/card.py +424 -0
- parrot/outputs/formats/chart.py +436 -0
- parrot/outputs/formats/d3.py +255 -0
- parrot/outputs/formats/echarts.py +310 -0
- parrot/outputs/formats/generators/__init__.py +0 -0
- parrot/outputs/formats/generators/abstract.py +61 -0
- parrot/outputs/formats/generators/panel.py +145 -0
- parrot/outputs/formats/generators/streamlit.py +86 -0
- parrot/outputs/formats/generators/terminal.py +63 -0
- parrot/outputs/formats/holoviews.py +310 -0
- parrot/outputs/formats/html.py +147 -0
- parrot/outputs/formats/jinja2.py +46 -0
- parrot/outputs/formats/json.py +87 -0
- parrot/outputs/formats/map.py +933 -0
- parrot/outputs/formats/markdown.py +172 -0
- parrot/outputs/formats/matplotlib.py +237 -0
- parrot/outputs/formats/mixins/__init__.py +0 -0
- parrot/outputs/formats/mixins/emaps.py +855 -0
- parrot/outputs/formats/plotly.py +341 -0
- parrot/outputs/formats/seaborn.py +310 -0
- parrot/outputs/formats/table.py +397 -0
- parrot/outputs/formats/template_report.py +138 -0
- parrot/outputs/formats/yaml.py +125 -0
- parrot/outputs/formatter.py +152 -0
- parrot/outputs/templates/__init__.py +95 -0
- parrot/pipelines/__init__.py +0 -0
- parrot/pipelines/abstract.py +210 -0
- parrot/pipelines/detector.py +124 -0
- parrot/pipelines/models.py +90 -0
- parrot/pipelines/planogram.py +3002 -0
- parrot/pipelines/table.sql +97 -0
- parrot/plugins/__init__.py +106 -0
- parrot/plugins/importer.py +80 -0
- parrot/py.typed +0 -0
- parrot/registry/__init__.py +18 -0
- parrot/registry/registry.py +594 -0
- parrot/scheduler/__init__.py +1189 -0
- parrot/scheduler/models.py +60 -0
- parrot/security/__init__.py +16 -0
- parrot/security/prompt_injection.py +268 -0
- parrot/security/security_events.sql +25 -0
- parrot/services/__init__.py +1 -0
- parrot/services/mcp/__init__.py +8 -0
- parrot/services/mcp/config.py +13 -0
- parrot/services/mcp/server.py +295 -0
- parrot/services/o365_remote_auth.py +235 -0
- parrot/stores/__init__.py +7 -0
- parrot/stores/abstract.py +352 -0
- parrot/stores/arango.py +1090 -0
- parrot/stores/bigquery.py +1377 -0
- parrot/stores/cache.py +106 -0
- parrot/stores/empty.py +10 -0
- parrot/stores/faiss_store.py +1157 -0
- parrot/stores/kb/__init__.py +9 -0
- parrot/stores/kb/abstract.py +68 -0
- parrot/stores/kb/cache.py +165 -0
- parrot/stores/kb/doc.py +325 -0
- parrot/stores/kb/hierarchy.py +346 -0
- parrot/stores/kb/local.py +457 -0
- parrot/stores/kb/prompt.py +28 -0
- parrot/stores/kb/redis.py +659 -0
- parrot/stores/kb/store.py +115 -0
- parrot/stores/kb/user.py +374 -0
- parrot/stores/models.py +59 -0
- parrot/stores/pgvector.py +3 -0
- parrot/stores/postgres.py +2853 -0
- parrot/stores/utils/__init__.py +0 -0
- parrot/stores/utils/chunking.py +197 -0
- parrot/telemetry/__init__.py +3 -0
- parrot/telemetry/mixin.py +111 -0
- parrot/template/__init__.py +3 -0
- parrot/template/engine.py +259 -0
- parrot/tools/__init__.py +23 -0
- parrot/tools/abstract.py +644 -0
- parrot/tools/agent.py +363 -0
- parrot/tools/arangodbsearch.py +537 -0
- parrot/tools/arxiv_tool.py +188 -0
- parrot/tools/calculator/__init__.py +3 -0
- parrot/tools/calculator/operations/__init__.py +38 -0
- parrot/tools/calculator/operations/calculus.py +80 -0
- parrot/tools/calculator/operations/statistics.py +76 -0
- parrot/tools/calculator/tool.py +150 -0
- parrot/tools/cloudwatch.py +988 -0
- parrot/tools/codeinterpreter/__init__.py +127 -0
- parrot/tools/codeinterpreter/executor.py +371 -0
- parrot/tools/codeinterpreter/internals.py +473 -0
- parrot/tools/codeinterpreter/models.py +643 -0
- parrot/tools/codeinterpreter/prompts.py +224 -0
- parrot/tools/codeinterpreter/tool.py +664 -0
- parrot/tools/company_info/__init__.py +6 -0
- parrot/tools/company_info/tool.py +1138 -0
- parrot/tools/correlationanalysis.py +437 -0
- parrot/tools/database/abstract.py +286 -0
- parrot/tools/database/bq.py +115 -0
- parrot/tools/database/cache.py +284 -0
- parrot/tools/database/models.py +95 -0
- parrot/tools/database/pg.py +343 -0
- parrot/tools/databasequery.py +1159 -0
- parrot/tools/db.py +1800 -0
- parrot/tools/ddgo.py +370 -0
- parrot/tools/decorators.py +271 -0
- parrot/tools/dftohtml.py +282 -0
- parrot/tools/document.py +549 -0
- parrot/tools/ecs.py +819 -0
- parrot/tools/edareport.py +368 -0
- parrot/tools/elasticsearch.py +1049 -0
- parrot/tools/employees.py +462 -0
- parrot/tools/epson/__init__.py +96 -0
- parrot/tools/excel.py +683 -0
- parrot/tools/file/__init__.py +13 -0
- parrot/tools/file/abstract.py +76 -0
- parrot/tools/file/gcs.py +378 -0
- parrot/tools/file/local.py +284 -0
- parrot/tools/file/s3.py +511 -0
- parrot/tools/file/tmp.py +309 -0
- parrot/tools/file/tool.py +501 -0
- parrot/tools/file_reader.py +129 -0
- parrot/tools/flowtask/__init__.py +19 -0
- parrot/tools/flowtask/tool.py +761 -0
- parrot/tools/gittoolkit.py +508 -0
- parrot/tools/google/__init__.py +18 -0
- parrot/tools/google/base.py +169 -0
- parrot/tools/google/tools.py +1251 -0
- parrot/tools/googlelocation.py +5 -0
- parrot/tools/googleroutes.py +5 -0
- parrot/tools/googlesearch.py +5 -0
- parrot/tools/googlesitesearch.py +5 -0
- parrot/tools/googlevoice.py +2 -0
- parrot/tools/gvoice.py +695 -0
- parrot/tools/ibisworld/README.md +225 -0
- parrot/tools/ibisworld/__init__.py +11 -0
- parrot/tools/ibisworld/tool.py +366 -0
- parrot/tools/jiratoolkit.py +1718 -0
- parrot/tools/manager.py +1098 -0
- parrot/tools/math.py +152 -0
- parrot/tools/metadata.py +476 -0
- parrot/tools/msteams.py +1621 -0
- parrot/tools/msword.py +635 -0
- parrot/tools/multidb.py +580 -0
- parrot/tools/multistoresearch.py +369 -0
- parrot/tools/networkninja.py +167 -0
- parrot/tools/nextstop/__init__.py +4 -0
- parrot/tools/nextstop/base.py +286 -0
- parrot/tools/nextstop/employee.py +733 -0
- parrot/tools/nextstop/store.py +462 -0
- parrot/tools/notification.py +435 -0
- parrot/tools/o365/__init__.py +42 -0
- parrot/tools/o365/base.py +295 -0
- parrot/tools/o365/bundle.py +522 -0
- parrot/tools/o365/events.py +554 -0
- parrot/tools/o365/mail.py +992 -0
- parrot/tools/o365/onedrive.py +497 -0
- parrot/tools/o365/sharepoint.py +641 -0
- parrot/tools/openapi_toolkit.py +904 -0
- parrot/tools/openweather.py +527 -0
- parrot/tools/pdfprint.py +1001 -0
- parrot/tools/powerbi.py +518 -0
- parrot/tools/powerpoint.py +1113 -0
- parrot/tools/pricestool.py +146 -0
- parrot/tools/products/__init__.py +246 -0
- parrot/tools/prophet_tool.py +171 -0
- parrot/tools/pythonpandas.py +630 -0
- parrot/tools/pythonrepl.py +910 -0
- parrot/tools/qsource.py +436 -0
- parrot/tools/querytoolkit.py +395 -0
- parrot/tools/quickeda.py +827 -0
- parrot/tools/resttool.py +553 -0
- parrot/tools/retail/__init__.py +0 -0
- parrot/tools/retail/bby.py +528 -0
- parrot/tools/sandboxtool.py +703 -0
- parrot/tools/sassie/__init__.py +352 -0
- parrot/tools/scraping/__init__.py +7 -0
- parrot/tools/scraping/docs/select.md +466 -0
- parrot/tools/scraping/documentation.md +1278 -0
- parrot/tools/scraping/driver.py +436 -0
- parrot/tools/scraping/models.py +576 -0
- parrot/tools/scraping/options.py +85 -0
- parrot/tools/scraping/orchestrator.py +517 -0
- parrot/tools/scraping/readme.md +740 -0
- parrot/tools/scraping/tool.py +3115 -0
- parrot/tools/seasonaldetection.py +642 -0
- parrot/tools/shell_tool/__init__.py +5 -0
- parrot/tools/shell_tool/actions.py +408 -0
- parrot/tools/shell_tool/engine.py +155 -0
- parrot/tools/shell_tool/models.py +322 -0
- parrot/tools/shell_tool/tool.py +442 -0
- parrot/tools/site_search.py +214 -0
- parrot/tools/textfile.py +418 -0
- parrot/tools/think.py +378 -0
- parrot/tools/toolkit.py +298 -0
- parrot/tools/webapp_tool.py +187 -0
- parrot/tools/whatif.py +1279 -0
- parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
- parrot/tools/workday/__init__.py +6 -0
- parrot/tools/workday/models.py +1389 -0
- parrot/tools/workday/tool.py +1293 -0
- parrot/tools/yfinance_tool.py +306 -0
- parrot/tools/zipcode.py +217 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/helpers.py +73 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.c +12078 -0
- parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/parsers/toml.pyx +21 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpp +20936 -0
- parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/types.pyx +213 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- parrot/yaml-rs/Cargo.lock +350 -0
- parrot/yaml-rs/Cargo.toml +19 -0
- parrot/yaml-rs/pyproject.toml +19 -0
- parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
- parrot/yaml-rs/src/lib.rs +222 -0
- requirements/docker-compose.yml +24 -0
- requirements/requirements-dev.txt +21 -0
parrot/clients/google.py
ADDED
|
@@ -0,0 +1,4567 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import sys
|
|
3
|
+
import asyncio
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Any, AsyncIterator, Dict, List, Optional, Union, Tuple
|
|
6
|
+
from functools import partial
|
|
7
|
+
import logging
|
|
8
|
+
import time
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
import contextlib
|
|
11
|
+
import io
|
|
12
|
+
import uuid
|
|
13
|
+
import aiofiles
|
|
14
|
+
import aiohttp
|
|
15
|
+
from PIL import Image
|
|
16
|
+
from google import genai
|
|
17
|
+
from google.genai.types import (
|
|
18
|
+
GenerateContentConfig,
|
|
19
|
+
Part,
|
|
20
|
+
ModelContent,
|
|
21
|
+
UserContent,
|
|
22
|
+
)
|
|
23
|
+
from google.oauth2 import service_account
|
|
24
|
+
from google.genai import types
|
|
25
|
+
from navconfig import config, BASE_DIR
|
|
26
|
+
import pandas as pd
|
|
27
|
+
from sklearn.base import defaultdict
|
|
28
|
+
from .base import (
|
|
29
|
+
AbstractClient,
|
|
30
|
+
ToolDefinition,
|
|
31
|
+
RetryConfig,
|
|
32
|
+
TokenRetryMixin,
|
|
33
|
+
StreamingRetryConfig
|
|
34
|
+
)
|
|
35
|
+
from ..models import (
|
|
36
|
+
AIMessage,
|
|
37
|
+
AIMessageFactory,
|
|
38
|
+
ToolCall,
|
|
39
|
+
StructuredOutputConfig,
|
|
40
|
+
OutputFormat,
|
|
41
|
+
CompletionUsage,
|
|
42
|
+
ImageGenerationPrompt,
|
|
43
|
+
SpeakerConfig,
|
|
44
|
+
SpeechGenerationPrompt,
|
|
45
|
+
VideoGenerationPrompt,
|
|
46
|
+
ObjectDetectionResult,
|
|
47
|
+
GoogleModel,
|
|
48
|
+
TTSVoice
|
|
49
|
+
)
|
|
50
|
+
from ..tools.abstract import AbstractTool, ToolResult
|
|
51
|
+
from ..models.outputs import (
|
|
52
|
+
SentimentAnalysis,
|
|
53
|
+
ProductReview
|
|
54
|
+
)
|
|
55
|
+
from ..models.google import (
|
|
56
|
+
ALL_VOICE_PROFILES,
|
|
57
|
+
VoiceRegistry,
|
|
58
|
+
ConversationalScriptConfig,
|
|
59
|
+
FictionalSpeaker
|
|
60
|
+
)
|
|
61
|
+
from ..exceptions import SpeechGenerationError # pylint: disable=E0611
|
|
62
|
+
from ..models.detections import (
|
|
63
|
+
DetectionBox,
|
|
64
|
+
ShelfRegion,
|
|
65
|
+
IdentifiedProduct,
|
|
66
|
+
IdentificationResponse
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
logging.getLogger(
|
|
70
|
+
name='PIL.TiffImagePlugin'
|
|
71
|
+
).setLevel(logging.ERROR) # Suppress TiffImagePlugin warnings
|
|
72
|
+
logging.getLogger(
|
|
73
|
+
name='google_genai'
|
|
74
|
+
).setLevel(logging.WARNING) # Suppress GenAI warnings
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class GoogleGenAIClient(AbstractClient):
|
|
78
|
+
"""
|
|
79
|
+
Client for interacting with Google's Generative AI, with support for parallel function calling.
|
|
80
|
+
|
|
81
|
+
Only Gemini-2.5-pro works well with multi-turn function calling.
|
|
82
|
+
Supports both API Key (Gemini Developer API) and Service Account (Vertex AI).
|
|
83
|
+
"""
|
|
84
|
+
client_type: str = 'google'
|
|
85
|
+
client_name: str = 'google'
|
|
86
|
+
_default_model: str = 'gemini-2.5-flash'
|
|
87
|
+
_model_garden: bool = False
|
|
88
|
+
|
|
89
|
+
def __init__(self, vertexai: bool = False, model_garden: bool = False, **kwargs):
|
|
90
|
+
self.model_garden = model_garden
|
|
91
|
+
self.vertexai: bool = True if model_garden else vertexai
|
|
92
|
+
self.vertex_location = kwargs.get('location', config.get('VERTEX_REGION'))
|
|
93
|
+
self.vertex_project = kwargs.get('project', config.get('VERTEX_PROJECT_ID'))
|
|
94
|
+
self._credentials_file = kwargs.get('credentials_file', config.get('VERTEX_CREDENTIALS_FILE'))
|
|
95
|
+
if isinstance(self._credentials_file, str):
|
|
96
|
+
self._credentials_file = Path(self._credentials_file).expanduser()
|
|
97
|
+
self.api_key = kwargs.pop('api_key', config.get('GOOGLE_API_KEY'))
|
|
98
|
+
super().__init__(**kwargs)
|
|
99
|
+
self.max_tokens = kwargs.get('max_tokens', 8192)
|
|
100
|
+
self.client = None
|
|
101
|
+
# Create a single instance of the Voice registry
|
|
102
|
+
self.voice_db = VoiceRegistry(profiles=ALL_VOICE_PROFILES)
|
|
103
|
+
|
|
104
|
+
async def get_client(self) -> genai.Client:
|
|
105
|
+
"""Get the underlying Google GenAI client."""
|
|
106
|
+
if self.vertexai:
|
|
107
|
+
self.logger.info(
|
|
108
|
+
f"Initializing Vertex AI for project {self.vertex_project} in {self.vertex_location}"
|
|
109
|
+
)
|
|
110
|
+
try:
|
|
111
|
+
if self._credentials_file and self._credentials_file.exists():
|
|
112
|
+
credentials = service_account.Credentials.from_service_account_file(
|
|
113
|
+
str(self._credentials_file)
|
|
114
|
+
)
|
|
115
|
+
else:
|
|
116
|
+
credentials = None # Use default credentials
|
|
117
|
+
|
|
118
|
+
return genai.Client(
|
|
119
|
+
vertexai=True,
|
|
120
|
+
project=self.vertex_project,
|
|
121
|
+
location=self.vertex_location,
|
|
122
|
+
credentials=credentials
|
|
123
|
+
)
|
|
124
|
+
except Exception as exc:
|
|
125
|
+
self.logger.error(f"Failed to initialize Vertex AI client: {exc}")
|
|
126
|
+
raise
|
|
127
|
+
return genai.Client(
|
|
128
|
+
api_key=self.api_key
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
async def close(self):
|
|
132
|
+
if self.client:
|
|
133
|
+
with contextlib.suppress(Exception):
|
|
134
|
+
await self.client._api_client._aiohttp_session.close() # pylint: disable=E1101 # noqa
|
|
135
|
+
self.client = None
|
|
136
|
+
|
|
137
|
+
def _fix_tool_schema(self, schema: dict):
|
|
138
|
+
"""Recursively converts schema type values to uppercase for GenAI compatibility."""
|
|
139
|
+
if isinstance(schema, dict):
|
|
140
|
+
for key, value in schema.items():
|
|
141
|
+
if key == 'type' and isinstance(value, str):
|
|
142
|
+
schema[key] = value.upper()
|
|
143
|
+
else:
|
|
144
|
+
self._fix_tool_schema(value)
|
|
145
|
+
elif isinstance(schema, list):
|
|
146
|
+
for item in schema:
|
|
147
|
+
self._fix_tool_schema(item)
|
|
148
|
+
return schema
|
|
149
|
+
|
|
150
|
+
def _analyze_prompt_for_tools(self, prompt: str) -> List[str]:
|
|
151
|
+
"""
|
|
152
|
+
Analyze the prompt to determine which tools might be needed.
|
|
153
|
+
This is a placeholder for more complex logic that could analyze the prompt.
|
|
154
|
+
"""
|
|
155
|
+
prompt_lower = prompt.lower()
|
|
156
|
+
# Keywords that suggest need for built-in tools
|
|
157
|
+
search_keywords = [
|
|
158
|
+
'search',
|
|
159
|
+
'find',
|
|
160
|
+
'google',
|
|
161
|
+
'web',
|
|
162
|
+
'internet',
|
|
163
|
+
'latest',
|
|
164
|
+
'news',
|
|
165
|
+
'weather'
|
|
166
|
+
]
|
|
167
|
+
has_search_intent = any(keyword in prompt_lower for keyword in search_keywords)
|
|
168
|
+
if has_search_intent:
|
|
169
|
+
return "builtin_tools"
|
|
170
|
+
else:
|
|
171
|
+
# Mixed intent - prefer custom functions if available, otherwise builtin
|
|
172
|
+
return "custom_functions"
|
|
173
|
+
|
|
174
|
+
def _resolve_schema_refs(self, schema: dict, defs: dict = None) -> dict:
|
|
175
|
+
"""
|
|
176
|
+
Recursively resolves $ref in JSON schema by inlining definitions.
|
|
177
|
+
This is crucial for Pydantic v2 schemas used with Gemini.
|
|
178
|
+
"""
|
|
179
|
+
if defs is None:
|
|
180
|
+
defs = schema.get('$defs', schema.get('definitions', {}))
|
|
181
|
+
|
|
182
|
+
if not isinstance(schema, dict):
|
|
183
|
+
return schema
|
|
184
|
+
|
|
185
|
+
# Handle $ref
|
|
186
|
+
if '$ref' in schema:
|
|
187
|
+
ref_path = schema['$ref']
|
|
188
|
+
# Extract definition name (e.g., "#/$defs/MyModel" -> "MyModel")
|
|
189
|
+
def_name = ref_path.split('/')[-1]
|
|
190
|
+
if def_name in defs:
|
|
191
|
+
# Get the definition
|
|
192
|
+
resolved = self._resolve_schema_refs(defs[def_name], defs)
|
|
193
|
+
# Merge with any other properties in the current schema (rare but possible)
|
|
194
|
+
merged = {k: v for k, v in schema.items() if k != '$ref'}
|
|
195
|
+
merged.update(resolved)
|
|
196
|
+
return merged
|
|
197
|
+
|
|
198
|
+
# Process children
|
|
199
|
+
new_schema = {}
|
|
200
|
+
for key, value in schema.items():
|
|
201
|
+
if key == 'properties' and isinstance(value, dict):
|
|
202
|
+
new_schema[key] = {
|
|
203
|
+
k: self._resolve_schema_refs(v, defs)
|
|
204
|
+
for k, v in value.items()
|
|
205
|
+
}
|
|
206
|
+
elif key == 'items' and isinstance(value, dict):
|
|
207
|
+
new_schema[key] = self._resolve_schema_refs(value, defs)
|
|
208
|
+
elif key in ('anyOf', 'allOf', 'oneOf') and isinstance(value, list):
|
|
209
|
+
new_schema[key] = [self._resolve_schema_refs(item, defs) for item in value]
|
|
210
|
+
else:
|
|
211
|
+
new_schema[key] = value
|
|
212
|
+
|
|
213
|
+
return new_schema
|
|
214
|
+
|
|
215
|
+
def clean_google_schema(self, schema: dict) -> dict:
|
|
216
|
+
"""
|
|
217
|
+
Clean a Pydantic-generated schema for Google Function Calling compatibility.
|
|
218
|
+
NOW INCLUDES: Reference resolution.
|
|
219
|
+
"""
|
|
220
|
+
if not isinstance(schema, dict):
|
|
221
|
+
return schema
|
|
222
|
+
|
|
223
|
+
# 1. Resolve References FIRST
|
|
224
|
+
# Pydantic v2 uses $defs, v1 uses definitions
|
|
225
|
+
if '$defs' in schema or 'definitions' in schema:
|
|
226
|
+
schema = self._resolve_schema_refs(schema)
|
|
227
|
+
|
|
228
|
+
cleaned = {}
|
|
229
|
+
|
|
230
|
+
# Fields that Google Function Calling supports
|
|
231
|
+
supported_fields = {
|
|
232
|
+
'type', 'description', 'enum', 'default', 'properties',
|
|
233
|
+
'required', 'items'
|
|
234
|
+
}
|
|
235
|
+
|
|
236
|
+
# Copy supported fields
|
|
237
|
+
for key, value in schema.items():
|
|
238
|
+
if key in supported_fields:
|
|
239
|
+
if key == 'properties':
|
|
240
|
+
cleaned[key] = {k: self.clean_google_schema(v) for k, v in value.items()}
|
|
241
|
+
elif key == 'items':
|
|
242
|
+
cleaned[key] = self.clean_google_schema(value)
|
|
243
|
+
else:
|
|
244
|
+
cleaned[key] = value
|
|
245
|
+
|
|
246
|
+
# ... [Rest of your existing type conversion logic stays the same] ...
|
|
247
|
+
if 'type' in cleaned:
|
|
248
|
+
if cleaned['type'] == 'integer':
|
|
249
|
+
cleaned['type'] = 'number' # Google prefers 'number' over 'integer'
|
|
250
|
+
elif cleaned['type'] == 'object' and 'properties' not in cleaned:
|
|
251
|
+
# Ensure objects have properties field, even if empty, to prevent confusion
|
|
252
|
+
cleaned['properties'] = {}
|
|
253
|
+
elif isinstance(cleaned['type'], list):
|
|
254
|
+
non_null_types = [t for t in cleaned['type'] if t != 'null']
|
|
255
|
+
cleaned['type'] = non_null_types[0] if non_null_types else 'string'
|
|
256
|
+
|
|
257
|
+
# Handle anyOf (union types) - Simplified for Gemini
|
|
258
|
+
if 'anyOf' in schema:
|
|
259
|
+
for option in schema['anyOf']:
|
|
260
|
+
if not isinstance(option, dict): continue
|
|
261
|
+
option_type = option.get('type')
|
|
262
|
+
if option_type and option_type != 'null':
|
|
263
|
+
cleaned['type'] = option_type
|
|
264
|
+
if option_type == 'array' and 'items' in option:
|
|
265
|
+
cleaned['items'] = self.clean_google_schema(option['items'])
|
|
266
|
+
if option_type == 'object' and 'properties' in option:
|
|
267
|
+
cleaned['properties'] = {k: self.clean_google_schema(v) for k, v in option['properties'].items()}
|
|
268
|
+
if 'required' in option:
|
|
269
|
+
cleaned['required'] = option['required']
|
|
270
|
+
break
|
|
271
|
+
if 'type' not in cleaned:
|
|
272
|
+
cleaned['type'] = 'string'
|
|
273
|
+
|
|
274
|
+
# Ensure object-like schemas always advertise an object type
|
|
275
|
+
if 'properties' in cleaned and cleaned.get('type') != 'object':
|
|
276
|
+
cleaned['type'] = 'object'
|
|
277
|
+
|
|
278
|
+
# Remove problematic fields
|
|
279
|
+
problematic_fields = {
|
|
280
|
+
'prefixItems', 'additionalItems', 'minItems', 'maxItems',
|
|
281
|
+
'minLength', 'maxLength', 'pattern', 'format', 'minimum',
|
|
282
|
+
'maximum', 'exclusiveMinimum', 'exclusiveMaximum', 'multipleOf',
|
|
283
|
+
'allOf', 'anyOf', 'oneOf', 'not', 'const', 'examples',
|
|
284
|
+
'$defs', 'definitions', '$ref', 'title', 'additionalProperties'
|
|
285
|
+
}
|
|
286
|
+
|
|
287
|
+
for field in problematic_fields:
|
|
288
|
+
cleaned.pop(field, None)
|
|
289
|
+
|
|
290
|
+
return cleaned
|
|
291
|
+
|
|
292
|
+
def _recursive_json_repair(self, data: Any) -> Any:
|
|
293
|
+
"""
|
|
294
|
+
Traverses a dictionary/list and attempts to parse string values
|
|
295
|
+
that look like JSON objects/lists.
|
|
296
|
+
"""
|
|
297
|
+
if isinstance(data, dict):
|
|
298
|
+
return {k: self._recursive_json_repair(v) for k, v in data.items()}
|
|
299
|
+
elif isinstance(data, list):
|
|
300
|
+
return [self._recursive_json_repair(item) for item in data]
|
|
301
|
+
elif isinstance(data, str):
|
|
302
|
+
data = data.strip()
|
|
303
|
+
# fast check if it looks like json
|
|
304
|
+
if (data.startswith('{') and data.endswith('}')) or \
|
|
305
|
+
(data.startswith('[') and data.endswith(']')):
|
|
306
|
+
try:
|
|
307
|
+
import json
|
|
308
|
+
parsed = json.loads(data)
|
|
309
|
+
# Recurse into the parsed object in case it has nested strings
|
|
310
|
+
return self._recursive_json_repair(parsed)
|
|
311
|
+
except (json.JSONDecodeError, TypeError):
|
|
312
|
+
return data
|
|
313
|
+
return data
|
|
314
|
+
|
|
315
|
+
def _apply_structured_output_schema(
|
|
316
|
+
self,
|
|
317
|
+
generation_config: Dict[str, Any],
|
|
318
|
+
output_config: Optional[StructuredOutputConfig]
|
|
319
|
+
) -> Optional[Dict[str, Any]]:
|
|
320
|
+
"""Apply a cleaned structured output schema to the generationho config."""
|
|
321
|
+
if not output_config or output_config.format != OutputFormat.JSON:
|
|
322
|
+
return None
|
|
323
|
+
|
|
324
|
+
try:
|
|
325
|
+
raw_schema = output_config.get_schema()
|
|
326
|
+
cleaned_schema = self.clean_google_schema(raw_schema)
|
|
327
|
+
fixed_schema = self._fix_tool_schema(cleaned_schema)
|
|
328
|
+
except Exception as exc:
|
|
329
|
+
self.logger.error(
|
|
330
|
+
f"Failed to generate structured output schema for Gemini: {exc}"
|
|
331
|
+
)
|
|
332
|
+
return None
|
|
333
|
+
|
|
334
|
+
generation_config["response_mime_type"] = "application/json"
|
|
335
|
+
generation_config["response_schema"] = fixed_schema
|
|
336
|
+
return fixed_schema
|
|
337
|
+
|
|
338
|
+
def _build_tools(self, tool_type: str, filter_names: Optional[List[str]] = None) -> Optional[List[types.Tool]]:
|
|
339
|
+
"""Build tools based on the specified type."""
|
|
340
|
+
if tool_type == "custom_functions":
|
|
341
|
+
# migrate to use abstractool + tool definition:
|
|
342
|
+
# Group function declarations by their category
|
|
343
|
+
declarations_by_category = defaultdict(list)
|
|
344
|
+
for tool in self.tool_manager.all_tools():
|
|
345
|
+
tool_name = tool.name
|
|
346
|
+
if filter_names is not None and tool_name not in filter_names:
|
|
347
|
+
continue
|
|
348
|
+
|
|
349
|
+
tool_name = tool.name
|
|
350
|
+
category = getattr(tool, 'category', 'tools')
|
|
351
|
+
if isinstance(tool, AbstractTool):
|
|
352
|
+
full_schema = tool.get_tool_schema()
|
|
353
|
+
tool_description = full_schema.get("description", tool.description)
|
|
354
|
+
# Extract ONLY the parameters part
|
|
355
|
+
schema = full_schema.get("parameters", {}).copy()
|
|
356
|
+
# Clean the schema for Google compatibility
|
|
357
|
+
schema = self.clean_google_schema(schema)
|
|
358
|
+
elif isinstance(tool, ToolDefinition):
|
|
359
|
+
tool_description = tool.description
|
|
360
|
+
schema = self.clean_google_schema(tool.input_schema.copy())
|
|
361
|
+
else:
|
|
362
|
+
# Fallback for other tool types
|
|
363
|
+
tool_description = getattr(tool, 'description', f"Tool: {tool_name}")
|
|
364
|
+
schema = getattr(tool, 'input_schema', {})
|
|
365
|
+
schema = self.clean_google_schema(schema)
|
|
366
|
+
|
|
367
|
+
# Ensure we have a valid parameters schema
|
|
368
|
+
if not schema:
|
|
369
|
+
schema = {
|
|
370
|
+
"type": "object",
|
|
371
|
+
"properties": {},
|
|
372
|
+
"required": []
|
|
373
|
+
}
|
|
374
|
+
try:
|
|
375
|
+
declaration = types.FunctionDeclaration(
|
|
376
|
+
name=tool_name,
|
|
377
|
+
description=tool_description,
|
|
378
|
+
parameters=self._fix_tool_schema(schema)
|
|
379
|
+
)
|
|
380
|
+
declarations_by_category[category].append(declaration)
|
|
381
|
+
except Exception as e:
|
|
382
|
+
self.logger.error(f"Error creating function declaration for {tool_name}: {e}")
|
|
383
|
+
# Skip this tool if it can't be created
|
|
384
|
+
continue
|
|
385
|
+
|
|
386
|
+
tool_list = []
|
|
387
|
+
for category, declarations in declarations_by_category.items():
|
|
388
|
+
if declarations:
|
|
389
|
+
tool_list.append(
|
|
390
|
+
types.Tool(
|
|
391
|
+
function_declarations=declarations
|
|
392
|
+
)
|
|
393
|
+
)
|
|
394
|
+
return tool_list
|
|
395
|
+
elif tool_type == "builtin_tools":
|
|
396
|
+
return [
|
|
397
|
+
types.Tool(
|
|
398
|
+
google_search=types.GoogleSearch()
|
|
399
|
+
),
|
|
400
|
+
]
|
|
401
|
+
|
|
402
|
+
return None
|
|
403
|
+
|
|
404
|
+
def _extract_function_calls(self, response) -> List:
|
|
405
|
+
"""Extract function calls from response - handles both proper function calls AND code blocks."""
|
|
406
|
+
function_calls = []
|
|
407
|
+
|
|
408
|
+
try:
|
|
409
|
+
if (response.candidates and
|
|
410
|
+
len(response.candidates) > 0 and
|
|
411
|
+
response.candidates[0].content and
|
|
412
|
+
response.candidates[0].content.parts):
|
|
413
|
+
|
|
414
|
+
for part in response.candidates[0].content.parts:
|
|
415
|
+
# First, check for proper function calls
|
|
416
|
+
if hasattr(part, 'function_call') and part.function_call:
|
|
417
|
+
function_calls.append(part.function_call)
|
|
418
|
+
self.logger.debug(f"Found proper function call: {part.function_call.name}")
|
|
419
|
+
|
|
420
|
+
# Second, check for text that contains tool code blocks
|
|
421
|
+
elif hasattr(part, 'text') and part.text and '```tool_code' in part.text:
|
|
422
|
+
self.logger.info("Found tool code block - parsing as function call")
|
|
423
|
+
code_block_calls = self._parse_tool_code_blocks(part.text)
|
|
424
|
+
function_calls.extend(code_block_calls)
|
|
425
|
+
|
|
426
|
+
except (AttributeError, IndexError) as e:
|
|
427
|
+
self.logger.debug(f"Error extracting function calls: {e}")
|
|
428
|
+
|
|
429
|
+
self.logger.debug(f"Total function calls extracted: {len(function_calls)}")
|
|
430
|
+
return function_calls
|
|
431
|
+
|
|
432
|
+
async def _handle_stateless_function_calls(
|
|
433
|
+
self,
|
|
434
|
+
response,
|
|
435
|
+
model: str,
|
|
436
|
+
contents: List,
|
|
437
|
+
config,
|
|
438
|
+
all_tool_calls: List[ToolCall],
|
|
439
|
+
original_prompt: Optional[str] = None
|
|
440
|
+
) -> Any:
|
|
441
|
+
"""Handle function calls in stateless mode (single request-response)."""
|
|
442
|
+
function_calls = self._extract_function_calls(response)
|
|
443
|
+
|
|
444
|
+
if not function_calls:
|
|
445
|
+
return response
|
|
446
|
+
|
|
447
|
+
# Execute function calls
|
|
448
|
+
tool_call_objects = []
|
|
449
|
+
for fc in function_calls:
|
|
450
|
+
tc = ToolCall(
|
|
451
|
+
id=f"call_{uuid.uuid4().hex[:8]}",
|
|
452
|
+
name=fc.name,
|
|
453
|
+
arguments=dict(fc.args)
|
|
454
|
+
)
|
|
455
|
+
tool_call_objects.append(tc)
|
|
456
|
+
|
|
457
|
+
start_time = time.time()
|
|
458
|
+
tool_execution_tasks = [
|
|
459
|
+
self._execute_tool(fc.name, dict(fc.args)) for fc in function_calls
|
|
460
|
+
]
|
|
461
|
+
tool_results = await asyncio.gather(
|
|
462
|
+
*tool_execution_tasks,
|
|
463
|
+
return_exceptions=True
|
|
464
|
+
)
|
|
465
|
+
execution_time = time.time() - start_time
|
|
466
|
+
|
|
467
|
+
for tc, result in zip(tool_call_objects, tool_results):
|
|
468
|
+
tc.execution_time = execution_time / len(tool_call_objects)
|
|
469
|
+
if isinstance(result, Exception):
|
|
470
|
+
tc.error = str(result)
|
|
471
|
+
else:
|
|
472
|
+
tc.result = result
|
|
473
|
+
|
|
474
|
+
all_tool_calls.extend(tool_call_objects)
|
|
475
|
+
|
|
476
|
+
# Prepare function responses
|
|
477
|
+
function_response_parts = []
|
|
478
|
+
for fc, result in zip(function_calls, tool_results):
|
|
479
|
+
if isinstance(result, Exception):
|
|
480
|
+
response_content = f"Error: {str(result)}"
|
|
481
|
+
else:
|
|
482
|
+
response_content = str(result.get('result', result) if isinstance(result, dict) else result)
|
|
483
|
+
|
|
484
|
+
function_response_parts.append(
|
|
485
|
+
Part(
|
|
486
|
+
function_response=types.FunctionResponse(
|
|
487
|
+
name=fc.name,
|
|
488
|
+
response={"result": response_content}
|
|
489
|
+
)
|
|
490
|
+
)
|
|
491
|
+
)
|
|
492
|
+
|
|
493
|
+
if summary_part := self._create_tool_summary_part(
|
|
494
|
+
function_calls,
|
|
495
|
+
tool_results,
|
|
496
|
+
original_prompt
|
|
497
|
+
):
|
|
498
|
+
function_response_parts.append(summary_part)
|
|
499
|
+
|
|
500
|
+
# Add function call and responses to conversation
|
|
501
|
+
contents.append({
|
|
502
|
+
"role": "model",
|
|
503
|
+
"parts": [{"function_call": fc} for fc in function_calls]
|
|
504
|
+
})
|
|
505
|
+
contents.append({
|
|
506
|
+
"role": "user",
|
|
507
|
+
"parts": function_response_parts
|
|
508
|
+
})
|
|
509
|
+
|
|
510
|
+
# Generate final response
|
|
511
|
+
final_response = await self.client.aio.models.generate_content(
|
|
512
|
+
model=model,
|
|
513
|
+
contents=contents,
|
|
514
|
+
config=config
|
|
515
|
+
)
|
|
516
|
+
|
|
517
|
+
return final_response
|
|
518
|
+
|
|
519
|
+
def _process_tool_result_for_api(self, result) -> dict:
|
|
520
|
+
"""
|
|
521
|
+
Process tool result for Google Function Calling API compatibility.
|
|
522
|
+
This method serializes various Python objects into a JSON-compatible
|
|
523
|
+
dictionary for the Google GenAI API.
|
|
524
|
+
"""
|
|
525
|
+
# 1. Handle exceptions and special wrapper types first
|
|
526
|
+
if isinstance(result, Exception):
|
|
527
|
+
return {"result": f"Tool execution failed: {str(result)}", "error": True}
|
|
528
|
+
|
|
529
|
+
# Handle ToolResult wrapper
|
|
530
|
+
if isinstance(result, ToolResult):
|
|
531
|
+
content = result.result
|
|
532
|
+
if result.metadata and 'stdout' in result.metadata:
|
|
533
|
+
# Prioritize stdout if exists
|
|
534
|
+
content = result.metadata['stdout']
|
|
535
|
+
content = result.metadata['stdout']
|
|
536
|
+
result = content # The actual result to process is the content
|
|
537
|
+
|
|
538
|
+
# Handle string results early (no conversion needed)
|
|
539
|
+
if isinstance(result, str):
|
|
540
|
+
if not result.strip():
|
|
541
|
+
return {"result": "Code executed successfully (no output)"}
|
|
542
|
+
return {"result": result}
|
|
543
|
+
|
|
544
|
+
# Convert complex types to basic Python types
|
|
545
|
+
clean_result = result
|
|
546
|
+
|
|
547
|
+
if isinstance(result, pd.DataFrame):
|
|
548
|
+
# Convert DataFrame to records and ensure all keys are strings
|
|
549
|
+
# This handles DataFrames with integer or other non-string column names
|
|
550
|
+
records = result.to_dict(orient='records')
|
|
551
|
+
clean_result = [
|
|
552
|
+
{str(k): v for k, v in record.items()}
|
|
553
|
+
for record in records
|
|
554
|
+
]
|
|
555
|
+
elif isinstance(result, list):
|
|
556
|
+
# Handle lists (including lists of Pydantic models)
|
|
557
|
+
clean_result = []
|
|
558
|
+
for item in result:
|
|
559
|
+
if hasattr(item, 'model_dump'): # Pydantic v2
|
|
560
|
+
clean_result.append(item.model_dump())
|
|
561
|
+
elif hasattr(item, 'dict'): # Pydantic v1
|
|
562
|
+
clean_result.append(item.dict())
|
|
563
|
+
else:
|
|
564
|
+
clean_result.append(item)
|
|
565
|
+
elif hasattr(result, 'model_dump'): # Pydantic v2 single model
|
|
566
|
+
clean_result = result.model_dump()
|
|
567
|
+
elif hasattr(result, 'dict'): # Pydantic v1 single model
|
|
568
|
+
clean_result = result.dict()
|
|
569
|
+
|
|
570
|
+
# 4. Attempt to serialize the processed result
|
|
571
|
+
try:
|
|
572
|
+
serialized = self._json.dumps(clean_result)
|
|
573
|
+
json_compatible_result = self._json.loads(serialized)
|
|
574
|
+
except Exception as e:
|
|
575
|
+
# This is the fallback for non-serializable objects (like PriceOutput)
|
|
576
|
+
self.logger.warning(
|
|
577
|
+
f"Could not serialize result of type {type(clean_result)} to JSON: {e}. "
|
|
578
|
+
"Falling back to string representation."
|
|
579
|
+
)
|
|
580
|
+
json_compatible_result = str(clean_result)
|
|
581
|
+
|
|
582
|
+
|
|
583
|
+
# Wrap for Google Function Calling format
|
|
584
|
+
if isinstance(json_compatible_result, dict) and 'result' in json_compatible_result:
|
|
585
|
+
return json_compatible_result
|
|
586
|
+
else:
|
|
587
|
+
return {"result": json_compatible_result}
|
|
588
|
+
|
|
589
|
+
def _summarize_tool_result(self, result: Any, max_length: int = 1200) -> str:
|
|
590
|
+
"""Create a short, human-readable summary of a tool result."""
|
|
591
|
+
|
|
592
|
+
try:
|
|
593
|
+
if isinstance(result, Exception):
|
|
594
|
+
summary = f"Error: {result}"
|
|
595
|
+
elif isinstance(result, pd.DataFrame):
|
|
596
|
+
preview = result.head(5)
|
|
597
|
+
summary = preview.to_string(index=True)
|
|
598
|
+
elif hasattr(result, 'model_dump'):
|
|
599
|
+
summary = self._json.dumps(result.model_dump())
|
|
600
|
+
elif isinstance(result, (dict, list)):
|
|
601
|
+
summary = self._json.dumps(result)
|
|
602
|
+
else:
|
|
603
|
+
summary = str(result)
|
|
604
|
+
except Exception as exc: # pylint: disable=broad-except
|
|
605
|
+
summary = f"Unable to summarize result: {exc}"
|
|
606
|
+
|
|
607
|
+
summary = summary.strip() or "[empty result]"
|
|
608
|
+
if len(summary) > max_length:
|
|
609
|
+
summary = summary[:max_length].rstrip() + "…"
|
|
610
|
+
return summary
|
|
611
|
+
|
|
612
|
+
def _create_tool_summary_part(
|
|
613
|
+
self,
|
|
614
|
+
function_calls,
|
|
615
|
+
tool_results,
|
|
616
|
+
original_prompt: Optional[str] = None
|
|
617
|
+
) -> Optional[Part]:
|
|
618
|
+
"""Build a textual summary of tool outputs for the model to read easily."""
|
|
619
|
+
|
|
620
|
+
if not function_calls or not tool_results:
|
|
621
|
+
return None
|
|
622
|
+
|
|
623
|
+
summary_lines = ["Tool execution summaries:"]
|
|
624
|
+
for fc, result in zip(function_calls, tool_results):
|
|
625
|
+
summary_lines.append(
|
|
626
|
+
f"- {fc.name}: {self._summarize_tool_result(result)}"
|
|
627
|
+
)
|
|
628
|
+
|
|
629
|
+
if original_prompt:
|
|
630
|
+
summary_lines.append(f"Original Request: {original_prompt}")
|
|
631
|
+
|
|
632
|
+
summary_lines.append(
|
|
633
|
+
"Use the information above to craft the final response without running redundant tool calls."
|
|
634
|
+
)
|
|
635
|
+
|
|
636
|
+
summary_text = "\n".join(summary_lines)
|
|
637
|
+
return Part(text=summary_text)
|
|
638
|
+
|
|
639
|
+
async def _handle_multiturn_function_calls(
|
|
640
|
+
self,
|
|
641
|
+
chat,
|
|
642
|
+
initial_response,
|
|
643
|
+
all_tool_calls: List[ToolCall],
|
|
644
|
+
original_prompt: Optional[str] = None,
|
|
645
|
+
model: str = None,
|
|
646
|
+
max_iterations: int = 10,
|
|
647
|
+
config: GenerateContentConfig = None,
|
|
648
|
+
max_retries: int = 1,
|
|
649
|
+
lazy_loading: bool = False,
|
|
650
|
+
active_tool_names: Optional[set] = None,
|
|
651
|
+
) -> Any:
|
|
652
|
+
"""
|
|
653
|
+
Simple multi-turn function calling - just keep going until no more function calls.
|
|
654
|
+
"""
|
|
655
|
+
current_response = initial_response
|
|
656
|
+
current_config = config
|
|
657
|
+
iteration = 0
|
|
658
|
+
|
|
659
|
+
if active_tool_names is None:
|
|
660
|
+
active_tool_names = set()
|
|
661
|
+
|
|
662
|
+
model = model or self.model
|
|
663
|
+
self.logger.info("Starting simple multi-turn function calling loop")
|
|
664
|
+
|
|
665
|
+
while iteration < max_iterations:
|
|
666
|
+
iteration += 1
|
|
667
|
+
|
|
668
|
+
# Get function calls (including converted from tool_code)
|
|
669
|
+
function_calls = self._get_function_calls_from_response(current_response)
|
|
670
|
+
if not function_calls:
|
|
671
|
+
# Check if we have any text content in the response
|
|
672
|
+
final_text = self._safe_extract_text(current_response)
|
|
673
|
+
self.logger.notice(f"🎯 Final Response from Gemini: {final_text[:200]}...")
|
|
674
|
+
if not final_text and all_tool_calls:
|
|
675
|
+
self.logger.warning(
|
|
676
|
+
"Final response is empty after tool execution, generating summary..."
|
|
677
|
+
)
|
|
678
|
+
try:
|
|
679
|
+
synthesis_prompt = """
|
|
680
|
+
Please now generate the complete response based on all the information gathered from the tools.
|
|
681
|
+
Provide a comprehensive answer to the original request.
|
|
682
|
+
Synthesize the data and provide insights, analysis, and conclusions as appropriate.
|
|
683
|
+
"""
|
|
684
|
+
current_response = await chat.send_message(
|
|
685
|
+
synthesis_prompt,
|
|
686
|
+
config=current_config
|
|
687
|
+
)
|
|
688
|
+
# Check if this worked
|
|
689
|
+
synthesis_text = self._safe_extract_text(current_response)
|
|
690
|
+
if synthesis_text:
|
|
691
|
+
self.logger.info("Successfully generated synthesis response")
|
|
692
|
+
else:
|
|
693
|
+
self.logger.warning("Synthesis attempt also returned empty response")
|
|
694
|
+
except Exception as e:
|
|
695
|
+
self.logger.error(f"Synthesis attempt failed: {e}")
|
|
696
|
+
|
|
697
|
+
self.logger.info(
|
|
698
|
+
f"No function calls found - completed after {iteration-1} iterations"
|
|
699
|
+
)
|
|
700
|
+
break
|
|
701
|
+
|
|
702
|
+
self.logger.info(
|
|
703
|
+
f"Iteration {iteration}: Processing {len(function_calls)} function calls"
|
|
704
|
+
)
|
|
705
|
+
|
|
706
|
+
# Execute function calls
|
|
707
|
+
tool_call_objects = []
|
|
708
|
+
for fc in function_calls:
|
|
709
|
+
tc = ToolCall(
|
|
710
|
+
id=f"call_{uuid.uuid4().hex[:8]}",
|
|
711
|
+
name=fc.name,
|
|
712
|
+
arguments=dict(fc.args) if hasattr(fc.args, 'items') else fc.args
|
|
713
|
+
)
|
|
714
|
+
tool_call_objects.append(tc)
|
|
715
|
+
|
|
716
|
+
# Execute tools
|
|
717
|
+
start_time = time.time()
|
|
718
|
+
tool_execution_tasks = [
|
|
719
|
+
self._execute_tool(fc.name, dict(fc.args) if hasattr(fc.args, 'items') else fc.args)
|
|
720
|
+
for fc in function_calls
|
|
721
|
+
]
|
|
722
|
+
tool_results = await asyncio.gather(*tool_execution_tasks, return_exceptions=True)
|
|
723
|
+
execution_time = time.time() - start_time
|
|
724
|
+
|
|
725
|
+
# Lazy Loading Check
|
|
726
|
+
if lazy_loading:
|
|
727
|
+
found_new = False
|
|
728
|
+
for fc, result in zip(function_calls, tool_results):
|
|
729
|
+
if fc.name == "search_tools" and isinstance(result, str):
|
|
730
|
+
new_tools = self._check_new_tools(fc.name, result)
|
|
731
|
+
for nt in new_tools:
|
|
732
|
+
if nt not in active_tool_names:
|
|
733
|
+
active_tool_names.add(nt)
|
|
734
|
+
found_new = True
|
|
735
|
+
|
|
736
|
+
if found_new:
|
|
737
|
+
# Rebuild tools with expanded set
|
|
738
|
+
new_tools_list = self._build_tools("custom_functions", filter_names=list(active_tool_names))
|
|
739
|
+
current_config.tools = new_tools_list
|
|
740
|
+
self.logger.info(f"Updated tools for next turn. Count: {len(active_tool_names)}")
|
|
741
|
+
|
|
742
|
+
# Update tool call objects
|
|
743
|
+
for tc, result in zip(tool_call_objects, tool_results):
|
|
744
|
+
tc.execution_time = execution_time / len(tool_call_objects)
|
|
745
|
+
if isinstance(result, Exception):
|
|
746
|
+
tc.error = str(result)
|
|
747
|
+
self.logger.error(f"Tool {tc.name} failed: {result}")
|
|
748
|
+
else:
|
|
749
|
+
tc.result = result
|
|
750
|
+
# self.logger.info(f"Tool {tc.name} result: {result}")
|
|
751
|
+
|
|
752
|
+
all_tool_calls.extend(tool_call_objects)
|
|
753
|
+
function_response_parts = []
|
|
754
|
+
            for fc, result in zip(function_calls, tool_results):
                tool_id = fc.id or f"call_{uuid.uuid4().hex[:8]}"
                self.logger.notice(f"🔍 Tool: {fc.name}")
                self.logger.notice(f"📤 Raw Result Type: {type(result)}")

                try:
                    response_content = self._process_tool_result_for_api(result)
                    self.logger.info(f"📦 Processed for API: {response_content}")

                    function_response_parts.append(
                        Part(
                            function_response=types.FunctionResponse(
                                id=tool_id,
                                name=fc.name,
                                response=response_content
                            )
                        )
                    )

                except Exception as e:
                    self.logger.error(f"Error processing result for tool {fc.name}: {e}")
                    function_response_parts.append(
                        Part(
                            function_response=types.FunctionResponse(
                                id=tool_id,
                                name=fc.name,
                                response={"result": f"Tool error: {str(e)}", "error": True}
                            )
                        )
                    )

            summary_part = self._create_tool_summary_part(
                function_calls,
                tool_results,
                original_prompt
            )
            # Combine the tool results with the textual summary prompt
            next_prompt_parts = function_response_parts.copy()
            if summary_part:
                next_prompt_parts.append(summary_part)

            # Send responses back
            retry_count = 0
            try:
                self.logger.debug(
                    f"Sending {len(next_prompt_parts)} responses back to model"
                )
                while retry_count < max_retries:
                    try:
                        current_response = await chat.send_message(
                            next_prompt_parts,
                            config=current_config
                        )
                        finish_reason = getattr(current_response.candidates[0], 'finish_reason', None)
                        if finish_reason:
                            if finish_reason.name == "MAX_TOKENS" and current_config.max_output_tokens < 8192:
                                self.logger.warning(
                                    "Hit MAX_TOKENS limit. Retrying with increased token limit."
                                )
                                retry_count += 1
                                current_config.max_output_tokens = 8192
                                continue
                            elif finish_reason.name == "MALFORMED_FUNCTION_CALL":
                                self.logger.warning(
                                    "Malformed function call detected. Retrying..."
                                )
                                retry_count += 1
                                await asyncio.sleep(2 ** retry_count)
                                continue
                        break
                    except Exception as e:
                        self.logger.error(f"Error sending message: {e}")
                        retry_count += 1
                        await asyncio.sleep(2 ** retry_count)  # Exponential backoff
                        if (retry_count + 1) >= max_retries:
                            self.logger.error("Max retries reached, aborting")
                            raise e

                # Check for UNEXPECTED_TOOL_CALL error
                if (hasattr(current_response, 'candidates') and
                        current_response.candidates and
                        hasattr(current_response.candidates[0], 'finish_reason')):

                    finish_reason = current_response.candidates[0].finish_reason

                    if str(finish_reason) == 'FinishReason.UNEXPECTED_TOOL_CALL':
                        self.logger.warning("Received UNEXPECTED_TOOL_CALL")

                        # Debug what we got back
                        if hasattr(current_response, 'text'):
                            try:
                                preview = current_response.text[:100] if current_response.text else "No text"
                                self.logger.debug(f"Response preview: {preview}")
                            except Exception:
                                self.logger.debug("Could not preview response text")

            except Exception as e:
                self.logger.error(f"Failed to send responses back: {e}")
                break

        self.logger.info(f"Completed with {len(all_tool_calls)} total tool calls")
        return current_response

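    # --- Illustrative sketch (not part of this module) ----------------------
    # The loop above relies on self._process_tool_result_for_api(), which is
    # defined elsewhere in this class and is not shown in this hunk. A minimal,
    # hypothetical conversion of that kind could look like the commented sketch
    # below; the name and behaviour here are assumptions, not the package's
    # actual implementation.
    #
    # import pandas as pd
    #
    # def _process_tool_result_for_api_sketch(result):
    #     """Coerce an arbitrary tool result into a JSON-safe dict."""
    #     if isinstance(result, Exception):
    #         return {"result": f"Tool error: {result}", "error": True}
    #     if isinstance(result, pd.DataFrame):
    #         # DataFrames are not JSON-serialisable as-is; send records instead.
    #         return {"result": result.to_dict(orient="records")}
    #     if isinstance(result, dict):
    #         return result
    #     return {"result": str(result)}
    # -------------------------------------------------------------------------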
    def _parse_tool_code_blocks(self, text: str) -> List:
        """Convert tool_code blocks to function call objects."""
        function_calls = []

        if '```tool_code' not in text:
            return function_calls

        # Simple regex to extract tool calls
        pattern = r'```tool_code\s*\n\s*print\(default_api\.(\w+)\((.*?)\)\)\s*\n\s*```'
        matches = re.findall(pattern, text, re.DOTALL)

        for tool_name, args_str in matches:
            self.logger.debug(f"Converting tool_code to function call: {tool_name}")
            try:
                # Parse arguments like: a = 9310, b = 3, operation = "divide"
                args = {}
                for arg_part in args_str.split(','):
                    if '=' in arg_part:
                        key, value = arg_part.split('=', 1)
                        key = key.strip()
                        value = value.strip().strip('"\'')  # Remove quotes

                        # Try to convert to number
                        try:
                            if '.' in value:
                                args[key] = float(value)
                            else:
                                args[key] = int(value)
                        except ValueError:
                            args[key] = value  # Keep as string
                # extract tool from Tool Manager
                tool = self.tool_manager.get_tool(tool_name)
                if tool:
                    # Create function call
                    fc = types.FunctionCall(
                        id=f"call_{uuid.uuid4().hex[:8]}",
                        name=tool_name,
                        args=args
                    )
                    function_calls.append(fc)
                    self.logger.info(f"Created function call: {tool_name}({args})")

            except Exception as e:
                self.logger.error(f"Failed to parse tool_code: {e}")

        return function_calls

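    # --- Illustrative example (not part of this module) ----------------------
    # _parse_tool_code_blocks() above converts Gemini "tool_code" text such as
    # the following into FunctionCall objects. The sample text and expected
    # output are illustrative assumptions, shown commented out:
    #
    # sample = (
    #     "```tool_code\n"
    #     "print(default_api.calculator(a = 9310, b = 3, operation = \"divide\"))\n"
    #     "```"
    # )
    # calls = client._parse_tool_code_blocks(sample)
    # # -> one types.FunctionCall with name="calculator"
    # #    and args={"a": 9310, "b": 3, "operation": "divide"},
    # #    provided a tool named "calculator" is registered in tool_manager.
    # --------------------------------------------------------------------------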
    def _get_function_calls_from_response(self, response) -> List:
        """Get function calls from response - handles both proper calls and tool_code blocks."""
        function_calls = []

        try:
            if (response.candidates and
                    response.candidates[0].content and
                    response.candidates[0].content.parts):

                for part in response.candidates[0].content.parts:
                    # Check for proper function calls first
                    if hasattr(part, 'function_call') and part.function_call:
                        function_calls.append(part.function_call)
                        self.logger.debug(
                            f"Found proper function call: {part.function_call.name}"
                        )

                    # Check for tool_code in text parts
                    elif hasattr(part, 'text') and part.text and '```tool_code' in part.text:
                        self.logger.info("Found tool_code block - converting to function call")
                        code_function_calls = self._parse_tool_code_blocks(part.text)
                        function_calls.extend(code_function_calls)

        except Exception as e:
            self.logger.error(f"Error getting function calls: {e}")

        self.logger.info(f"Total function calls found: {len(function_calls)}")
        return function_calls

|
|
933
|
+
def _safe_extract_text(self, response) -> str:
|
|
934
|
+
"""
|
|
935
|
+
Enhanced text extraction that handles reasoning models and mixed content warnings.
|
|
936
|
+
|
|
937
|
+
This method tries multiple approaches to extract text from Google GenAI responses,
|
|
938
|
+
handling special cases like thought_signature parts from reasoning models.
|
|
939
|
+
"""
|
|
940
|
+
|
|
941
|
+
# Pre-check for function calls to avoid library warnings when accessing .text
|
|
942
|
+
has_function_call = False
|
|
943
|
+
try:
|
|
944
|
+
if (hasattr(response, 'candidates') and response.candidates and
|
|
945
|
+
len(response.candidates) > 0 and hasattr(response.candidates[0], 'content') and
|
|
946
|
+
response.candidates[0].content and hasattr(response.candidates[0].content, 'parts') and
|
|
947
|
+
response.candidates[0].content.parts):
|
|
948
|
+
for part in response.candidates[0].content.parts:
|
|
949
|
+
if hasattr(part, 'function_call') and part.function_call:
|
|
950
|
+
has_function_call = True
|
|
951
|
+
break
|
|
952
|
+
except Exception:
|
|
953
|
+
pass
|
|
954
|
+
|
|
955
|
+
# Method 1: Try response.text first (fastest path)
|
|
956
|
+
# Skip if we found a function call, as accessing .text triggers a warning in the library
|
|
957
|
+
if not has_function_call:
|
|
958
|
+
try:
|
|
959
|
+
if hasattr(response, 'text') and response.text:
|
|
960
|
+
if (text := response.text.strip()):
|
|
961
|
+
self.logger.debug(
|
|
962
|
+
f"Extracted text via response.text: '{text[:100]}...'"
|
|
963
|
+
)
|
|
964
|
+
return text
|
|
965
|
+
except Exception as e:
|
|
966
|
+
# This is expected with reasoning models that have mixed content
|
|
967
|
+
self.logger.debug(
|
|
968
|
+
f"response.text failed (normal for reasoning models): {e}"
|
|
969
|
+
)
|
|
970
|
+
|
|
971
|
+
# Method 2: Manual extraction from parts (more robust)
|
|
972
|
+
try:
|
|
973
|
+
if (hasattr(response, 'candidates') and response.candidates and len(response.candidates) > 0 and
|
|
974
|
+
hasattr(response.candidates[0], 'content') and response.candidates[0].content and
|
|
975
|
+
hasattr(response.candidates[0].content, 'parts') and response.candidates[0].content.parts):
|
|
976
|
+
|
|
977
|
+
text_parts = []
|
|
978
|
+
thought_parts_found = 0
|
|
979
|
+
|
|
980
|
+
# Extract text from each part, handling special cases
|
|
981
|
+
for part in response.candidates[0].content.parts:
|
|
982
|
+
# Check for regular text content
|
|
983
|
+
if hasattr(part, 'text') and part.text:
|
|
984
|
+
if (clean_text := part.text.strip()):
|
|
985
|
+
text_parts.append(clean_text)
|
|
986
|
+
self.logger.debug(
|
|
987
|
+
f"Found text part: '{clean_text[:50]}...'"
|
|
988
|
+
)
|
|
989
|
+
|
|
990
|
+
# Log non-text parts but don't extract them
|
|
991
|
+
elif hasattr(part, 'thought_signature'):
|
|
992
|
+
thought_parts_found += 1
|
|
993
|
+
self.logger.debug(
|
|
994
|
+
"Found thought_signature part (reasoning model internal thought)"
|
|
995
|
+
)
|
|
996
|
+
|
|
997
|
+
# Log reasoning model detection
|
|
998
|
+
if thought_parts_found > 0:
|
|
999
|
+
self.logger.debug(
|
|
1000
|
+
f"Detected reasoning model with {thought_parts_found} thought parts"
|
|
1001
|
+
)
|
|
1002
|
+
|
|
1003
|
+
# Combine text parts
|
|
1004
|
+
if text_parts:
|
|
1005
|
+
if (combined_text := "".join(text_parts).strip()):
|
|
1006
|
+
self.logger.debug(
|
|
1007
|
+
f"Successfully extracted text from {len(text_parts)} parts"
|
|
1008
|
+
)
|
|
1009
|
+
return combined_text
|
|
1010
|
+
else:
|
|
1011
|
+
self.logger.debug("No text parts found in response parts")
|
|
1012
|
+
|
|
1013
|
+
except Exception as e:
|
|
1014
|
+
self.logger.error(f"Manual text extraction failed: {e}")
|
|
1015
|
+
|
|
1016
|
+
# Method 3: Deep inspection for debugging (fallback)
|
|
1017
|
+
try:
|
|
1018
|
+
if hasattr(response, 'candidates') and response.candidates:
|
|
1019
|
+
candidate = response.candidates[0] if len(response.candidates) > 0 else None
|
|
1020
|
+
if candidate:
|
|
1021
|
+
if hasattr(candidate, 'finish_reason'):
|
|
1022
|
+
finish_reason = str(candidate.finish_reason)
|
|
1023
|
+
self.logger.debug(f"Response finish reason: {finish_reason}")
|
|
1024
|
+
if 'MAX_TOKENS' in finish_reason:
|
|
1025
|
+
self.logger.warning("Response truncated due to token limit")
|
|
1026
|
+
elif 'SAFETY' in finish_reason:
|
|
1027
|
+
self.logger.warning("Response blocked by safety filters")
|
|
1028
|
+
elif 'STOP' in finish_reason:
|
|
1029
|
+
self.logger.debug("Response completed normally but no text found")
|
|
1030
|
+
|
|
1031
|
+
if hasattr(candidate, 'content') and candidate.content:
|
|
1032
|
+
if hasattr(candidate.content, 'parts'):
|
|
1033
|
+
parts_count = len(candidate.content.parts) if candidate.content.parts else 0
|
|
1034
|
+
self.logger.debug(f"Response has {parts_count} parts but no extractable text")
|
|
1035
|
+
if candidate.content.parts:
|
|
1036
|
+
part_types = []
|
|
1037
|
+
for part in candidate.content.parts:
|
|
1038
|
+
part_attrs = [attr for attr in dir(part)
|
|
1039
|
+
if not attr.startswith('_') and hasattr(part, attr) and getattr(part, attr)]
|
|
1040
|
+
part_types.append(part_attrs)
|
|
1041
|
+
self.logger.debug(f"Part attribute types found: {part_types}")
|
|
1042
|
+
|
|
1043
|
+
except Exception as e:
|
|
1044
|
+
self.logger.error(f"Deep inspection failed: {e}")
|
|
1045
|
+
|
|
1046
|
+
# Method 4: Final fallback - return empty string with clear logging
|
|
1047
|
+
self.logger.warning(
|
|
1048
|
+
"Could not extract any text from response using any method"
|
|
1049
|
+
)
|
|
1050
|
+
return ""
|
|
1051
|
+
|
|
1052
|
+
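    # --- Illustrative usage sketch (not part of this module) -----------------
    # A hedged example of how ask() below might be called from application
    # code; the client class name and argument values are assumptions for
    # illustration and are commented out so they do not alter this file.
    #
    # async def example():
    #     client = GoogleGenAIClient()           # hypothetical constructor name
    #     answer = await client.ask(
    #         prompt="What is 9310 divided by 3?",
    #         model=GoogleModel.GEMINI_2_5_FLASH,
    #         use_tools=True,                     # enable registered tools
    #         stateless=True,                     # skip conversation memory
    #     )
    #     print(answer.provider, answer.tool_calls)
    # --------------------------------------------------------------------------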
async def ask(
|
|
1053
|
+
self,
|
|
1054
|
+
prompt: str,
|
|
1055
|
+
model: Union[str, GoogleModel] = None,
|
|
1056
|
+
max_tokens: Optional[int] = None,
|
|
1057
|
+
temperature: Optional[float] = None,
|
|
1058
|
+
files: Optional[List[Union[str, Path]]] = None,
|
|
1059
|
+
system_prompt: Optional[str] = None,
|
|
1060
|
+
structured_output: Union[type, StructuredOutputConfig] = None,
|
|
1061
|
+
user_id: Optional[str] = None,
|
|
1062
|
+
session_id: Optional[str] = None,
|
|
1063
|
+
tools: Optional[List[Dict[str, Any]]] = None,
|
|
1064
|
+
use_tools: Optional[bool] = None,
|
|
1065
|
+
stateless: bool = False,
|
|
1066
|
+
deep_research: bool = False,
|
|
1067
|
+
background: bool = False,
|
|
1068
|
+
file_search_store_names: Optional[List[str]] = None,
|
|
1069
|
+
lazy_loading: bool = False,
|
|
1070
|
+
**kwargs
|
|
1071
|
+
) -> AIMessage:
|
|
1072
|
+
"""
|
|
1073
|
+
Ask a question to Google's Generative AI with support for parallel tool calls.
|
|
1074
|
+
|
|
1075
|
+
Args:
|
|
1076
|
+
prompt (str): The input prompt for the model.
|
|
1077
|
+
model (Union[str, GoogleModel]): The model to use. If None, uses the client's configured model
|
|
1078
|
+
or defaults to GEMINI_2_5_FLASH.
|
|
1079
|
+
max_tokens (int): Maximum number of tokens in the response.
|
|
1080
|
+
temperature (float): Sampling temperature for response generation.
|
|
1081
|
+
files (Optional[List[Union[str, Path]]]): Optional files to include in the request.
|
|
1082
|
+
system_prompt (Optional[str]): Optional system prompt to guide the model.
|
|
1083
|
+
structured_output (Union[type, StructuredOutputConfig]): Optional structured output configuration.
|
|
1084
|
+
user_id (Optional[str]): Optional user identifier for tracking.
|
|
1085
|
+
session_id: Optional session identifier for tracking.
|
|
1086
|
+
force_tool_usage (Optional[str]): Force usage of specific tools, if needed.
|
|
1087
|
+
("custom_functions", "builtin_tools", or None)
|
|
1088
|
+
stateless (bool): If True, don't use conversation memory (stateless mode).
|
|
1089
|
+
deep_research (bool): If True, use Google's deep research agent.
|
|
1090
|
+
background (bool): If True, execute deep research in background mode.
|
|
1091
|
+
file_search_store_names (Optional[List[str]]): Names of file search stores for deep research.
|
|
1092
|
+
"""
|
|
1093
|
+
max_retries = kwargs.pop('max_retries', 1)
|
|
1094
|
+
|
|
1095
|
+
# Route to deep research if requested
|
|
1096
|
+
if deep_research:
|
|
1097
|
+
self.logger.info("Using Google Deep Research mode via interactions.create()")
|
|
1098
|
+
return await self._deep_research_ask(
|
|
1099
|
+
prompt=prompt,
|
|
1100
|
+
background=background,
|
|
1101
|
+
file_search_store_names=file_search_store_names,
|
|
1102
|
+
user_id=user_id,
|
|
1103
|
+
session_id=session_id
|
|
1104
|
+
)
|
|
1105
|
+
|
|
1106
|
+
model = model.value if isinstance(model, GoogleModel) else model
|
|
1107
|
+
# If use_tools is None, use the instance default
|
|
1108
|
+
_use_tools = use_tools if use_tools is not None else self.enable_tools
|
|
1109
|
+
if not model:
|
|
1110
|
+
model = self.model or GoogleModel.GEMINI_2_5_FLASH.value
|
|
1111
|
+
|
|
1112
|
+
# Handle case where model is passed as a tuple or list
|
|
1113
|
+
if isinstance(model, (list, tuple)):
|
|
1114
|
+
model = model[0]
|
|
1115
|
+
|
|
1116
|
+
# Generate unique turn ID for tracking
|
|
1117
|
+
turn_id = str(uuid.uuid4())
|
|
1118
|
+
original_prompt = prompt
|
|
1119
|
+
|
|
1120
|
+
# Prepare conversation context using unified memory system
|
|
1121
|
+
conversation_history = None
|
|
1122
|
+
messages = []
|
|
1123
|
+
|
|
1124
|
+
# Use the abstract method to prepare conversation context
|
|
1125
|
+
if stateless:
|
|
1126
|
+
# For stateless mode, skip conversation memory
|
|
1127
|
+
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
|
1128
|
+
conversation_history = None
|
|
1129
|
+
else:
|
|
1130
|
+
# Use the unified conversation context preparation from AbstractClient
|
|
1131
|
+
messages, conversation_history, system_prompt = await self._prepare_conversation_context(
|
|
1132
|
+
prompt, files, user_id, session_id, system_prompt, stateless=stateless
|
|
1133
|
+
)
|
|
1134
|
+
|
|
1135
|
+
# Prepare conversation history for Google GenAI format
|
|
1136
|
+
history = []
|
|
1137
|
+
# Construct history directly from the 'messages' array, which should be in the correct format
|
|
1138
|
+
if messages:
|
|
1139
|
+
for msg in messages[:-1]: # Exclude the current user message (last in list)
|
|
1140
|
+
role = msg['role'].lower()
|
|
1141
|
+
# Assuming content is already in the format [{"type": "text", "text": "..."}]
|
|
1142
|
+
# or other GenAI Part types if files were involved.
|
|
1143
|
+
# Here, we only expect text content for history, as images/files are for the current turn.
|
|
1144
|
+
if role == 'user':
|
|
1145
|
+
# Content can be a list of dicts (for text/parts) or a single string.
|
|
1146
|
+
# Standardize to list of Parts.
|
|
1147
|
+
parts = []
|
|
1148
|
+
for part_content in msg.get('content', []):
|
|
1149
|
+
if isinstance(part_content, dict) and part_content.get('type') == 'text':
|
|
1150
|
+
parts.append(Part(text=part_content.get('text', '')))
|
|
1151
|
+
# Add other part types if necessary for history (e.g., function responses)
|
|
1152
|
+
if parts:
|
|
1153
|
+
history.append(UserContent(parts=parts))
|
|
1154
|
+
elif role in ['assistant', 'model']:
|
|
1155
|
+
parts = []
|
|
1156
|
+
for part_content in msg.get('content', []):
|
|
1157
|
+
if isinstance(part_content, dict) and part_content.get('type') == 'text':
|
|
1158
|
+
parts.append(Part(text=part_content.get('text', '')))
|
|
1159
|
+
if parts:
|
|
1160
|
+
history.append(ModelContent(parts=parts))
|
|
1161
|
+
|
|
1162
|
+
default_tokens = max_tokens or self.max_tokens or 4096
|
|
1163
|
+
generation_config = {
|
|
1164
|
+
"max_output_tokens": default_tokens,
|
|
1165
|
+
"temperature": temperature or self.temperature
|
|
1166
|
+
}
|
|
1167
|
+
base_temperature = generation_config["temperature"]
|
|
1168
|
+
|
|
1169
|
+
# Prepare structured output configuration
|
|
1170
|
+
output_config = self._get_structured_config(structured_output)
|
|
1171
|
+
|
|
1172
|
+
# Tool selection
|
|
1173
|
+
requested_tools = tools
|
|
1174
|
+
|
|
1175
|
+
if _use_tools:
|
|
1176
|
+
if requested_tools and isinstance(requested_tools, list):
|
|
1177
|
+
for tool in requested_tools:
|
|
1178
|
+
self.register_tool(tool)
|
|
1179
|
+
tool_type = "custom_functions"
|
|
1180
|
+
# if Tools, reduce temperature to avoid hallucinations.
|
|
1181
|
+
generation_config["temperature"] = 0
|
|
1182
|
+
elif _use_tools is None:
|
|
1183
|
+
# If not explicitly set, analyze the prompt to decide
|
|
1184
|
+
tool_type = self._analyze_prompt_for_tools(prompt)
|
|
1185
|
+
else:
|
|
1186
|
+
tool_type = 'builtin_tools' if _use_tools else None
|
|
1187
|
+
|
|
1188
|
+
tools = self._build_tools(tool_type) if tool_type else []
|
|
1189
|
+
|
|
1190
|
+
if _use_tools and tool_type == "custom_functions" and not tools:
|
|
1191
|
+
self.logger.info(
|
|
1192
|
+
"Tool usage requested but no tools are registered - disabling tools for this request."
|
|
1193
|
+
)
|
|
1194
|
+
_use_tools = False
|
|
1195
|
+
tool_type = None
|
|
1196
|
+
tools = []
|
|
1197
|
+
generation_config["temperature"] = base_temperature
|
|
1198
|
+
|
|
1199
|
+
use_tools = _use_tools
|
|
1200
|
+
|
|
1201
|
+
# LAZY LOADING LOGIC
|
|
1202
|
+
active_tool_names = set()
|
|
1203
|
+
if use_tools and lazy_loading:
|
|
1204
|
+
# Override initial tool selection to just search_tools
|
|
1205
|
+
active_tool_names.add("search_tools")
|
|
1206
|
+
tools = self._build_tools("custom_functions", filter_names=["search_tools"])
|
|
1207
|
+
# Add system prompt instruction
|
|
1208
|
+
search_prompt = "You have access to a library of tools. Use the 'search_tools' function to find relevant tools."
|
|
1209
|
+
system_prompt = f"{system_prompt}\n\n{search_prompt}" if system_prompt else search_prompt
|
|
1210
|
+
# Update final_config later with this new system prompt if needed,
|
|
1211
|
+
# but system_prompt is passed to GenerateContentConfig below.
|
|
1212
|
+
|
|
1213
|
+
|
|
1214
|
+
self.logger.debug(
|
|
1215
|
+
f"Using model: {model}, max_tokens: {default_tokens}, temperature: {temperature}, "
|
|
1216
|
+
f"structured_output: {structured_output}, "
|
|
1217
|
+
f"use_tools: {_use_tools}, tool_type: {tool_type}, toolbox: {len(tools)}, "
|
|
1218
|
+
)
|
|
1219
|
+
|
|
1220
|
+
use_structured_output = bool(output_config)
|
|
1221
|
+
# Google limitation: Cannot combine tools with structured output
|
|
1222
|
+
# Strategy: If both are requested, use tools first, then apply structured output to final result
|
|
1223
|
+
if _use_tools and use_structured_output:
|
|
1224
|
+
self.logger.info(
|
|
1225
|
+
"Google Gemini doesn't support tools + structured output simultaneously. "
|
|
1226
|
+
"Using tools first, then applying structured output to the final result."
|
|
1227
|
+
)
|
|
1228
|
+
structured_output_for_later = output_config
|
|
1229
|
+
# Don't set structured output in initial config
|
|
1230
|
+
output_config = None
|
|
1231
|
+
else:
|
|
1232
|
+
structured_output_for_later = None
|
|
1233
|
+
# Set structured output in generation config if no tools conflict
|
|
1234
|
+
if output_config:
|
|
1235
|
+
self._apply_structured_output_schema(generation_config, output_config)
|
|
1236
|
+
|
|
1237
|
+
# Track tool calls for the response
|
|
1238
|
+
all_tool_calls = []
|
|
1239
|
+
# Build contents for conversation
|
|
1240
|
+
contents = []
|
|
1241
|
+
|
|
1242
|
+
for msg in messages:
|
|
1243
|
+
role = "model" if msg["role"] == "assistant" else msg["role"]
|
|
1244
|
+
if role in ["user", "model"]:
|
|
1245
|
+
text_parts = [part["text"] for part in msg["content"] if "text" in part]
|
|
1246
|
+
if text_parts:
|
|
1247
|
+
contents.append({
|
|
1248
|
+
"role": role,
|
|
1249
|
+
"parts": [{"text": " ".join(text_parts)}]
|
|
1250
|
+
})
|
|
1251
|
+
|
|
1252
|
+
# Add the current prompt
|
|
1253
|
+
contents.append({
|
|
1254
|
+
"role": "user",
|
|
1255
|
+
"parts": [{"text": prompt}]
|
|
1256
|
+
})
|
|
1257
|
+
|
|
1258
|
+
chat = None
|
|
1259
|
+
if not self.client:
|
|
1260
|
+
self.client = await self.get_client()
|
|
1261
|
+
final_config = GenerateContentConfig(
|
|
1262
|
+
system_instruction=system_prompt,
|
|
1263
|
+
safety_settings=[
|
|
1264
|
+
types.SafetySetting(
|
|
1265
|
+
category=types.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
|
1266
|
+
threshold=types.HarmBlockThreshold.BLOCK_NONE,
|
|
1267
|
+
),
|
|
1268
|
+
types.SafetySetting(
|
|
1269
|
+
category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
|
1270
|
+
threshold=types.HarmBlockThreshold.BLOCK_NONE,
|
|
1271
|
+
),
|
|
1272
|
+
],
|
|
1273
|
+
tools=tools,
|
|
1274
|
+
**generation_config
|
|
1275
|
+
)
|
|
1276
|
+
if stateless:
|
|
1277
|
+
# For stateless mode, handle in a single call (existing behavior)
|
|
1278
|
+
contents = []
|
|
1279
|
+
|
|
1280
|
+
for msg in messages:
|
|
1281
|
+
role = "model" if msg["role"] == "assistant" else msg["role"]
|
|
1282
|
+
if role in ["user", "model"]:
|
|
1283
|
+
text_parts = [part["text"] for part in msg["content"] if "text" in part]
|
|
1284
|
+
if text_parts:
|
|
1285
|
+
contents.append({
|
|
1286
|
+
"role": role,
|
|
1287
|
+
"parts": [{"text": " ".join(text_parts)}]
|
|
1288
|
+
})
|
|
1289
|
+
try:
|
|
1290
|
+
retry_count = 0
|
|
1291
|
+
while retry_count < max_retries:
|
|
1292
|
+
response = await self.client.aio.models.generate_content(
|
|
1293
|
+
model=model,
|
|
1294
|
+
contents=contents,
|
|
1295
|
+
config=final_config
|
|
1296
|
+
)
|
|
1297
|
+
finish_reason = getattr(response.candidates[0], 'finish_reason', None)
|
|
1298
|
+
if finish_reason and finish_reason.name == "MAX_TOKENS" and generation_config["max_output_tokens"] == 1024:
|
|
1299
|
+
retry_count += 1
|
|
1300
|
+
self.logger.warning(
|
|
1301
|
+
f"Hit MAX_TOKENS limit on stateless response. Retrying {retry_count}/{max_retries} with increased token limit."
|
|
1302
|
+
)
|
|
1303
|
+
final_config.max_output_tokens = 8192
|
|
1304
|
+
continue
|
|
1305
|
+
break
|
|
1306
|
+
except Exception as e:
|
|
1307
|
+
self.logger.error(
|
|
1308
|
+
f"Error during generate_content: {e}"
|
|
1309
|
+
)
|
|
1310
|
+
if (retry_count + 1) >= max_retries:
|
|
1311
|
+
raise e
|
|
1312
|
+
retry_count += 1
|
|
1313
|
+
|
|
1314
|
+
# Handle function calls in stateless mode
|
|
1315
|
+
final_response = await self._handle_stateless_function_calls(
|
|
1316
|
+
response,
|
|
1317
|
+
model,
|
|
1318
|
+
contents,
|
|
1319
|
+
final_config,
|
|
1320
|
+
all_tool_calls,
|
|
1321
|
+
original_prompt=prompt
|
|
1322
|
+
)
|
|
1323
|
+
else:
|
|
1324
|
+
# MULTI-TURN CONVERSATION MODE
|
|
1325
|
+
chat = self.client.aio.chats.create(
|
|
1326
|
+
model=model,
|
|
1327
|
+
history=history
|
|
1328
|
+
)
|
|
1329
|
+
retry_count = 0
|
|
1330
|
+
# Send initial message
|
|
1331
|
+
while retry_count < max_retries:
|
|
1332
|
+
try:
|
|
1333
|
+
response = await chat.send_message(
|
|
1334
|
+
message=prompt,
|
|
1335
|
+
config=final_config
|
|
1336
|
+
)
|
|
1337
|
+
finish_reason = getattr(response.candidates[0], 'finish_reason', None)
|
|
1338
|
+
if finish_reason and finish_reason.name == "MAX_TOKENS" and generation_config["max_output_tokens"] <= 1024:
|
|
1339
|
+
retry_count += 1
|
|
1340
|
+
self.logger.warning(
|
|
1341
|
+
f"Hit MAX_TOKENS limit on initial response. Retrying {retry_count}/{max_retries} with increased token limit."
|
|
1342
|
+
)
|
|
1343
|
+
final_config.max_output_tokens = 8192
|
|
1344
|
+
continue
|
|
1345
|
+
break
|
|
1346
|
+
except Exception as e:
|
|
1347
|
+
# Handle specific network client error (socket/aiohttp issue)
|
|
1348
|
+
if "'NoneType' object has no attribute 'getaddrinfo'" in str(e):
|
|
1349
|
+
self.logger.warning(
|
|
1350
|
+
f"Encountered network client error: {e}. Resetting client and retrying."
|
|
1351
|
+
)
|
|
1352
|
+
# Reset the client
|
|
1353
|
+
self.client = None
|
|
1354
|
+
if not self.client:
|
|
1355
|
+
self.client = await self.get_client()
|
|
1356
|
+
# Recreate the chat session
|
|
1357
|
+
chat = self.client.aio.chats.create(
|
|
1358
|
+
model=model,
|
|
1359
|
+
history=history
|
|
1360
|
+
)
|
|
1361
|
+
retry_count += 1
|
|
1362
|
+
continue
|
|
1363
|
+
|
|
1364
|
+
self.logger.error(
|
|
1365
|
+
f"Error during initial chat.send_message: {e}"
|
|
1366
|
+
)
|
|
1367
|
+
if (retry_count + 1) >= max_retries:
|
|
1368
|
+
raise e
|
|
1369
|
+
retry_count += 1
|
|
1370
|
+
|
|
1371
|
+
has_function_calls = False
|
|
1372
|
+
if response and getattr(response, "candidates", None):
|
|
1373
|
+
candidate = response.candidates[0] if response.candidates else None
|
|
1374
|
+
content = getattr(candidate, "content", None) if candidate else None
|
|
1375
|
+
parts = getattr(content, "parts", None) if content else None
|
|
1376
|
+
has_function_calls = bool(parts)
|
|
1377
|
+
|
|
1378
|
+
self.logger.debug(
|
|
1379
|
+
f"Initial response has function calls: {has_function_calls}"
|
|
1380
|
+
)
|
|
1381
|
+
|
|
1382
|
+
# Multi-turn function calling loop
|
|
1383
|
+
final_response = await self._handle_multiturn_function_calls(
|
|
1384
|
+
chat,
|
|
1385
|
+
response,
|
|
1386
|
+
all_tool_calls,
|
|
1387
|
+
original_prompt=original_prompt,
|
|
1388
|
+
model=model,
|
|
1389
|
+
max_iterations=10,
|
|
1390
|
+
config=final_config,
|
|
1391
|
+
max_retries=max_retries,
|
|
1392
|
+
lazy_loading=lazy_loading,
|
|
1393
|
+
active_tool_names=active_tool_names
|
|
1394
|
+
)
|
|
1395
|
+
|
|
1396
|
+
# Extract assistant response text for conversation memory
|
|
1397
|
+
assistant_response_text = self._safe_extract_text(final_response)
|
|
1398
|
+
|
|
1399
|
+
# If we still don't have text but have tool calls, generate a summary
|
|
1400
|
+
if not assistant_response_text and all_tool_calls:
|
|
1401
|
+
assistant_response_text = self._create_simple_summary(
|
|
1402
|
+
all_tool_calls
|
|
1403
|
+
)
|
|
1404
|
+
|
|
1405
|
+
# Handle structured output
|
|
1406
|
+
final_output = None
|
|
1407
|
+
if structured_output_for_later and use_tools and assistant_response_text:
|
|
1408
|
+
try:
|
|
1409
|
+
# Create a new generation config for structured output only
|
|
1410
|
+
structured_config = {
|
|
1411
|
+
"max_output_tokens": max_tokens or self.max_tokens,
|
|
1412
|
+
"temperature": temperature or self.temperature,
|
|
1413
|
+
"response_mime_type": "application/json"
|
|
1414
|
+
}
|
|
1415
|
+
# Set the schema based on the type of structured output
|
|
1416
|
+
schema_config = (
|
|
1417
|
+
structured_output_for_later
|
|
1418
|
+
if isinstance(structured_output_for_later, StructuredOutputConfig)
|
|
1419
|
+
else self._get_structured_config(structured_output_for_later)
|
|
1420
|
+
)
|
|
1421
|
+
if schema_config:
|
|
1422
|
+
self._apply_structured_output_schema(structured_config, schema_config)
|
|
1423
|
+
# Create a new client call without tools for structured output
|
|
1424
|
+
format_prompt = (
|
|
1425
|
+
f"Please format the following information according to the requested JSON structure. "
|
|
1426
|
+
f"Return only the JSON object with the requested fields:\n\n{assistant_response_text}"
|
|
1427
|
+
)
|
|
1428
|
+
structured_response = await self.client.aio.models.generate_content(
|
|
1429
|
+
model=model,
|
|
1430
|
+
contents=[{"role": "user", "parts": [{"text": format_prompt}]}],
|
|
1431
|
+
config=GenerateContentConfig(**structured_config)
|
|
1432
|
+
)
|
|
1433
|
+
# Extract structured text
|
|
1434
|
+
if structured_text := self._safe_extract_text(structured_response):
|
|
1435
|
+
# Parse the structured output
|
|
1436
|
+
if isinstance(structured_output_for_later, StructuredOutputConfig):
|
|
1437
|
+
final_output = await self._parse_structured_output(
|
|
1438
|
+
structured_text,
|
|
1439
|
+
structured_output_for_later
|
|
1440
|
+
)
|
|
1441
|
+
elif isinstance(structured_output_for_later, type):
|
|
1442
|
+
if hasattr(structured_output_for_later, 'model_validate_json'):
|
|
1443
|
+
final_output = structured_output_for_later.model_validate_json(structured_text)
|
|
1444
|
+
elif hasattr(structured_output_for_later, 'model_validate'):
|
|
1445
|
+
parsed_json = self._json.loads(structured_text)
|
|
1446
|
+
final_output = structured_output_for_later.model_validate(parsed_json)
|
|
1447
|
+
else:
|
|
1448
|
+
final_output = self._json.loads(structured_text)
|
|
1449
|
+
else:
|
|
1450
|
+
final_output = self._json.loads(structured_text)
|
|
1451
|
+
# # --- Fallback Logic ---
|
|
1452
|
+
# is_json_format = (
|
|
1453
|
+
# isinstance(structured_output_for_later, StructuredOutputConfig) and
|
|
1454
|
+
# structured_output_for_later.format == OutputFormat.JSON
|
|
1455
|
+
# )
|
|
1456
|
+
# if is_json_format and isinstance(final_output, str):
|
|
1457
|
+
# try:
|
|
1458
|
+
# self._json.loads(final_output)
|
|
1459
|
+
# except Exception:
|
|
1460
|
+
# self.logger.warning(
|
|
1461
|
+
# "Structured output re-formatting resulted in invalid/truncated JSON. "
|
|
1462
|
+
# "Falling back to original tool output."
|
|
1463
|
+
# )
|
|
1464
|
+
# final_output = assistant_response_text
|
|
1465
|
+
else:
|
|
1466
|
+
self.logger.warning(
|
|
1467
|
+
"No structured text received, falling back to original response"
|
|
1468
|
+
)
|
|
1469
|
+
final_output = assistant_response_text
|
|
1470
|
+
except Exception as e:
|
|
1471
|
+
self.logger.error(f"Error parsing structured output: {e}")
|
|
1472
|
+
# Fallback to original text if structured output fails
|
|
1473
|
+
final_output = assistant_response_text
|
|
1474
|
+
elif output_config and not use_tools:
|
|
1475
|
+
try:
|
|
1476
|
+
final_output = await self._parse_structured_output(
|
|
1477
|
+
assistant_response_text,
|
|
1478
|
+
output_config
|
|
1479
|
+
)
|
|
1480
|
+
except Exception:
|
|
1481
|
+
final_output = assistant_response_text
|
|
1482
|
+
else:
|
|
1483
|
+
final_output = assistant_response_text
|
|
1484
|
+
|
|
1485
|
+
# Update conversation memory with the final response
|
|
1486
|
+
final_assistant_message = {
|
|
1487
|
+
"role": "model",
|
|
1488
|
+
"content": [
|
|
1489
|
+
{
|
|
1490
|
+
"type": "text",
|
|
1491
|
+
"text": str(final_output) if final_output != assistant_response_text else assistant_response_text
|
|
1492
|
+
}
|
|
1493
|
+
]
|
|
1494
|
+
}
|
|
1495
|
+
|
|
1496
|
+
# Update conversation memory with unified system
|
|
1497
|
+
if not stateless and conversation_history:
|
|
1498
|
+
tools_used = [tc.name for tc in all_tool_calls]
|
|
1499
|
+
await self._update_conversation_memory(
|
|
1500
|
+
user_id,
|
|
1501
|
+
session_id,
|
|
1502
|
+
conversation_history,
|
|
1503
|
+
messages + [final_assistant_message],
|
|
1504
|
+
system_prompt,
|
|
1505
|
+
turn_id,
|
|
1506
|
+
original_prompt,
|
|
1507
|
+
assistant_response_text,
|
|
1508
|
+
tools_used
|
|
1509
|
+
)
|
|
1510
|
+
# Create AIMessage using factory
|
|
1511
|
+
ai_message = AIMessageFactory.from_gemini(
|
|
1512
|
+
response=response,
|
|
1513
|
+
input_text=original_prompt,
|
|
1514
|
+
model=model,
|
|
1515
|
+
user_id=user_id,
|
|
1516
|
+
session_id=session_id,
|
|
1517
|
+
turn_id=turn_id,
|
|
1518
|
+
structured_output=final_output,
|
|
1519
|
+
tool_calls=all_tool_calls,
|
|
1520
|
+
conversation_history=conversation_history,
|
|
1521
|
+
text_response=assistant_response_text
|
|
1522
|
+
)
|
|
1523
|
+
|
|
1524
|
+
# Override provider to distinguish from Vertex AI
|
|
1525
|
+
ai_message.provider = "google_genai"
|
|
1526
|
+
|
|
1527
|
+
return ai_message

    def _create_simple_summary(self, all_tool_calls: List[ToolCall]) -> str:
        """Create a simple summary from tool calls."""
        if not all_tool_calls:
            return "Task completed."

        if len(all_tool_calls) == 1:
            tc = all_tool_calls[0]
            if isinstance(tc.result, Exception):
                return f"Tool {tc.name} failed with error: {tc.result}"
            elif isinstance(tc.result, pd.DataFrame):
                if not tc.result.empty:
                    return f"Tool {tc.name} returned a DataFrame with {len(tc.result)} rows."
                else:
                    return f"Tool {tc.name} returned an empty DataFrame."
            elif tc.result and isinstance(tc.result, dict) and 'expression' in tc.result:
                return tc.result['expression']
            elif tc.result and isinstance(tc.result, dict) and 'result' in tc.result:
                return f"Result: {tc.result['result']}"
        else:
            # Multiple calls - show the final result
            final_tc = all_tool_calls[-1]
            if isinstance(final_tc.result, pd.DataFrame):
                if not final_tc.result.empty:
                    return f"Data: {final_tc.result.to_string()}"
                else:
                    return f"Final tool {final_tc.name} returned an empty DataFrame."
            if final_tc.result and isinstance(final_tc.result, dict):
                if 'result' in final_tc.result:
                    return f"Final result: {final_tc.result['result']}"
                elif 'expression' in final_tc.result:
                    return final_tc.result['expression']

        return "Calculation completed."

    def _build_function_declarations(self) -> List[types.FunctionDeclaration]:
        """Build function declarations for Google GenAI tools."""
        function_declarations = []

        for tool in self.tool_manager.all_tools():
            tool_name = tool.name

            if isinstance(tool, AbstractTool):
                full_schema = tool.get_tool_schema()
                tool_description = full_schema.get("description", tool.description)
                schema = full_schema.get("parameters", {}).copy()
                schema = self.clean_google_schema(schema)
            elif isinstance(tool, ToolDefinition):
                tool_description = tool.description
                schema = self.clean_google_schema(tool.input_schema.copy())
            else:
                tool_description = getattr(tool, 'description', f"Tool: {tool_name}")
                schema = getattr(tool, 'input_schema', {})
                schema = self.clean_google_schema(schema)

            if not schema:
                schema = {"type": "object", "properties": {}, "required": []}

            try:
                declaration = types.FunctionDeclaration(
                    name=tool_name,
                    description=tool_description,
                    parameters=self._fix_tool_schema(schema)
                )
                function_declarations.append(declaration)
            except Exception as e:
                self.logger.error(f"Error creating {tool_name}: {e}")
                continue

        return function_declarations

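    # --- Illustrative sketch (not part of this module) ------------------------
    # The declarations produced by _build_function_declarations() above are
    # typically wrapped in a google-genai Tool and attached to the generation
    # config, roughly as sketched below (commented out; _build_tools() in this
    # class handles the real wiring):
    #
    # declarations = self._build_function_declarations()
    # tool = types.Tool(function_declarations=declarations)
    # config = GenerateContentConfig(
    #     tools=[tool],
    #     temperature=0,    # custom tools are run at low temperature in this client
    # )
    # ---------------------------------------------------------------------------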
async def ask_stream(
|
|
1600
|
+
self,
|
|
1601
|
+
prompt: str,
|
|
1602
|
+
model: Union[str, GoogleModel] = None,
|
|
1603
|
+
max_tokens: Optional[int] = None,
|
|
1604
|
+
temperature: Optional[float] = None,
|
|
1605
|
+
files: Optional[List[Union[str, Path]]] = None,
|
|
1606
|
+
system_prompt: Optional[str] = None,
|
|
1607
|
+
user_id: Optional[str] = None,
|
|
1608
|
+
session_id: Optional[str] = None,
|
|
1609
|
+
retry_config: Optional[StreamingRetryConfig] = None,
|
|
1610
|
+
on_max_tokens: Optional[str] = "retry", # "retry", "notify", "ignore"
|
|
1611
|
+
tools: Optional[List[Dict[str, Any]]] = None,
|
|
1612
|
+
use_tools: Optional[bool] = None,
|
|
1613
|
+
deep_research: bool = False,
|
|
1614
|
+
agent_config: Optional[Dict[str, Any]] = None,
|
|
1615
|
+
lazy_loading: bool = False,
|
|
1616
|
+
) -> AsyncIterator[str]:
|
|
1617
|
+
"""
|
|
1618
|
+
Stream Google Generative AI's response using AsyncIterator with support for Tool Calling.
|
|
1619
|
+
|
|
1620
|
+
Args:
|
|
1621
|
+
on_max_tokens: How to handle MAX_TOKENS finish reason:
|
|
1622
|
+
- "retry": Automatically retry with increased token limit
|
|
1623
|
+
- "notify": Yield a notification message and continue
|
|
1624
|
+
- "ignore": Silently continue (original behavior)
|
|
1625
|
+
deep_research: If True, use Google's deep research agent (stream mode)
|
|
1626
|
+
agent_config: Optional configuration for deep research (e.g., thinking_summaries)
|
|
1627
|
+
"""
|
|
1628
|
+
model = (
|
|
1629
|
+
model.value if isinstance(model, GoogleModel) else model
|
|
1630
|
+
) or (self.model or GoogleModel.GEMINI_2_5_FLASH.value)
|
|
1631
|
+
|
|
1632
|
+
# Handle case where model is passed as a tuple or list
|
|
1633
|
+
if isinstance(model, (list, tuple)):
|
|
1634
|
+
model = model[0]
|
|
1635
|
+
|
|
1636
|
+
# Stub for deep research streaming
|
|
1637
|
+
if deep_research:
|
|
1638
|
+
self.logger.warning(
|
|
1639
|
+
"Google Deep Research streaming is not yet fully implemented. "
|
|
1640
|
+
"Falling back to standard ask_stream() behavior."
|
|
1641
|
+
)
|
|
1642
|
+
# TODO: Implement interactions.create(stream=True) when SDK supports it
|
|
1643
|
+
# For now, just use regular streaming
|
|
1644
|
+
|
|
1645
|
+
turn_id = str(uuid.uuid4())
|
|
1646
|
+
# Default retry configuration
|
|
1647
|
+
if retry_config is None:
|
|
1648
|
+
retry_config = StreamingRetryConfig()
|
|
1649
|
+
|
|
1650
|
+
# Use the unified conversation context preparation from AbstractClient
|
|
1651
|
+
messages, conversation_history, system_prompt = await self._prepare_conversation_context(
|
|
1652
|
+
prompt, files, user_id, session_id, system_prompt
|
|
1653
|
+
)
|
|
1654
|
+
|
|
1655
|
+
# Prepare conversation history for Google GenAI format
|
|
1656
|
+
history = []
|
|
1657
|
+
if messages:
|
|
1658
|
+
for msg in messages[:-1]: # Exclude the current user message (last in list)
|
|
1659
|
+
role = msg['role'].lower()
|
|
1660
|
+
if role == 'user':
|
|
1661
|
+
parts = []
|
|
1662
|
+
for part_content in msg.get('content', []):
|
|
1663
|
+
if isinstance(part_content, dict) and part_content.get('type') == 'text':
|
|
1664
|
+
parts.append(Part(text=part_content.get('text', '')))
|
|
1665
|
+
if parts:
|
|
1666
|
+
history.append(UserContent(parts=parts))
|
|
1667
|
+
elif role in ['assistant', 'model']:
|
|
1668
|
+
parts = []
|
|
1669
|
+
for part_content in msg.get('content', []):
|
|
1670
|
+
if isinstance(part_content, dict) and part_content.get('type') == 'text':
|
|
1671
|
+
parts.append(Part(text=part_content.get('text', '')))
|
|
1672
|
+
if parts:
|
|
1673
|
+
history.append(ModelContent(parts=parts))
|
|
1674
|
+
|
|
1675
|
+
# --- Tool Configuration (Mirrored from ask method) ---
|
|
1676
|
+
_use_tools = use_tools if use_tools is not None else self.enable_tools
|
|
1677
|
+
|
|
1678
|
+
# Register requested tools if any
|
|
1679
|
+
if tools and isinstance(tools, list):
|
|
1680
|
+
for tool in tools:
|
|
1681
|
+
self.register_tool(tool)
|
|
1682
|
+
|
|
1683
|
+
# Determine tool strategy
|
|
1684
|
+
if _use_tools:
|
|
1685
|
+
# If explicit tools passed or just enabled, force low temp
|
|
1686
|
+
temperature = 0 if temperature is None else temperature
|
|
1687
|
+
tool_type = "custom_functions"
|
|
1688
|
+
elif _use_tools is None:
|
|
1689
|
+
# Analyze prompt
|
|
1690
|
+
tool_type = self._analyze_prompt_for_tools(prompt)
|
|
1691
|
+
else:
|
|
1692
|
+
tool_type = 'builtin_tools' if _use_tools else None
|
|
1693
|
+
|
|
1694
|
+
# Build the actual tool objects for Gemini
|
|
1695
|
+
gemini_tools = self._build_tools(tool_type) if tool_type else []
|
|
1696
|
+
|
|
1697
|
+
if _use_tools and tool_type == "custom_functions" and not gemini_tools:
|
|
1698
|
+
# Fallback if no tools registered
|
|
1699
|
+
gemini_tools = None
|
|
1700
|
+
|
|
1701
|
+
# --- Execution Loop ---
|
|
1702
|
+
|
|
1703
|
+
# Retry loop variables
|
|
1704
|
+
current_max_tokens = max_tokens or self.max_tokens
|
|
1705
|
+
retry_count = 0
|
|
1706
|
+
|
|
1707
|
+
# Variables for multi-turn tool loop
|
|
1708
|
+
current_message_content = prompt # Start with the user prompt
|
|
1709
|
+
keep_looping = True
|
|
1710
|
+
|
|
1711
|
+
# Start the chat session once
|
|
1712
|
+
chat = self.client.aio.chats.create(
|
|
1713
|
+
model=model,
|
|
1714
|
+
history=history,
|
|
1715
|
+
config=GenerateContentConfig(
|
|
1716
|
+
system_instruction=system_prompt,
|
|
1717
|
+
tools=gemini_tools,
|
|
1718
|
+
temperature=temperature or self.temperature,
|
|
1719
|
+
max_output_tokens=current_max_tokens
|
|
1720
|
+
)
|
|
1721
|
+
)
|
|
1722
|
+
|
|
1723
|
+
all_assistant_text = [] # Keep track of full text for memory update
|
|
1724
|
+
|
|
1725
|
+
while keep_looping and retry_count <= retry_config.max_retries:
|
|
1726
|
+
# By default, we stop after one turn unless a tool is called
|
|
1727
|
+
keep_looping = False
|
|
1728
|
+
|
|
1729
|
+
try:
|
|
1730
|
+
# If we are retrying due to max tokens, update config
|
|
1731
|
+
chat._config.max_output_tokens = current_max_tokens
|
|
1732
|
+
|
|
1733
|
+
assistant_content_chunk = ""
|
|
1734
|
+
max_tokens_reached = False
|
|
1735
|
+
|
|
1736
|
+
# We need to capture function calls from the chunks as they arrive
|
|
1737
|
+
collected_function_calls = []
|
|
1738
|
+
|
|
1739
|
+
async for chunk in await chat.send_message_stream(current_message_content):
|
|
1740
|
+
# Check for MAX_TOKENS finish reason
|
|
1741
|
+
if (hasattr(chunk, 'candidates') and chunk.candidates and len(chunk.candidates) > 0):
|
|
1742
|
+
candidate = chunk.candidates[0]
|
|
1743
|
+
if (hasattr(candidate, 'finish_reason') and
|
|
1744
|
+
str(candidate.finish_reason) == 'FinishReason.MAX_TOKENS'):
|
|
1745
|
+
max_tokens_reached = True
|
|
1746
|
+
|
|
1747
|
+
if on_max_tokens == "notify":
|
|
1748
|
+
yield f"\n\n⚠️ **Response truncated due to token limit ({current_max_tokens} tokens).**\n"
|
|
1749
|
+
elif on_max_tokens == "retry" and retry_config.auto_retry_on_max_tokens:
|
|
1750
|
+
# Break inner loop to handle retry in outer loop
|
|
1751
|
+
break
|
|
1752
|
+
|
|
1753
|
+
# Capture function calls from the chunk
|
|
1754
|
+
if (hasattr(chunk, 'candidates') and chunk.candidates):
|
|
1755
|
+
for candidate in chunk.candidates:
|
|
1756
|
+
if hasattr(candidate, 'content') and candidate.content and candidate.content.parts:
|
|
1757
|
+
for part in candidate.content.parts:
|
|
1758
|
+
if hasattr(part, 'function_call') and part.function_call:
|
|
1759
|
+
collected_function_calls.append(part.function_call)
|
|
1760
|
+
|
|
1761
|
+
# Yield text content if present
|
|
1762
|
+
if chunk.text:
|
|
1763
|
+
assistant_content_chunk += chunk.text
|
|
1764
|
+
all_assistant_text.append(chunk.text)
|
|
1765
|
+
yield chunk.text
|
|
1766
|
+
|
|
1767
|
+
# --- Handle Max Tokens Retry ---
|
|
1768
|
+
if max_tokens_reached and on_max_tokens == "retry" and retry_config.auto_retry_on_max_tokens:
|
|
1769
|
+
if retry_count < retry_config.max_retries:
|
|
1770
|
+
new_max_tokens = int(current_max_tokens * retry_config.token_increase_factor)
|
|
1771
|
+
yield f"\n\n🔄 **Retrying with increased limit ({new_max_tokens})...**\n\n"
|
|
1772
|
+
current_max_tokens = new_max_tokens
|
|
1773
|
+
retry_count += 1
|
|
1774
|
+
await self._wait_with_backoff(retry_count, retry_config)
|
|
1775
|
+
keep_looping = True # Force loop to continue
|
|
1776
|
+
continue
|
|
1777
|
+
else:
|
|
1778
|
+
yield f"\n\n❌ **Maximum retries reached.**\n"
|
|
1779
|
+
|
|
1780
|
+
# --- Handle Function Calls ---
|
|
1781
|
+
if collected_function_calls:
|
|
1782
|
+
# We have tool calls to execute!
|
|
1783
|
+
self.logger.info(f"Streaming detected {len(collected_function_calls)} tool calls.")
|
|
1784
|
+
|
|
1785
|
+
# Execute tools (parallel)
|
|
1786
|
+
tool_execution_tasks = [
|
|
1787
|
+
self._execute_tool(fc.name, dict(fc.args))
|
|
1788
|
+
for fc in collected_function_calls
|
|
1789
|
+
]
|
|
1790
|
+
tool_results = await asyncio.gather(*tool_execution_tasks, return_exceptions=True)
|
|
1791
|
+
|
|
1792
|
+
# Build the response parts containing tool outputs
|
|
1793
|
+
function_response_parts = []
|
|
1794
|
+
for fc, result in zip(collected_function_calls, tool_results):
|
|
1795
|
+
response_content = self._process_tool_result_for_api(result)
|
|
1796
|
+
function_response_parts.append(
|
|
1797
|
+
Part(
|
|
1798
|
+
function_response=types.FunctionResponse(
|
|
1799
|
+
name=fc.name,
|
|
1800
|
+
response=response_content
|
|
1801
|
+
)
|
|
1802
|
+
)
|
|
1803
|
+
)
|
|
1804
|
+
|
|
1805
|
+
# Set the next message to be these tool outputs
|
|
1806
|
+
current_message_content = function_response_parts
|
|
1807
|
+
|
|
1808
|
+
# Force the loop to run again to stream the answer based on these tools
|
|
1809
|
+
keep_looping = True
|
|
1810
|
+
|
|
1811
|
+
except Exception as e:
|
|
1812
|
+
# Handle specific network client error
|
|
1813
|
+
if "'NoneType' object has no attribute 'getaddrinfo'" in str(e):
|
|
1814
|
+
if retry_count < retry_config.max_retries:
|
|
1815
|
+
self.logger.warning(
|
|
1816
|
+
f"Encountered network client error during stream: {e}. Resetting client..."
|
|
1817
|
+
)
|
|
1818
|
+
self.client = None
|
|
1819
|
+
if not self.client:
|
|
1820
|
+
self.client = await self.get_client()
|
|
1821
|
+
|
|
1822
|
+
# Recreate chat session
|
|
1823
|
+
# Note: We rely on history variable being the initial history.
|
|
1824
|
+
# Intermediate turn state might be lost if this happens mid-conversation,
|
|
1825
|
+
# but this error usually happens at connection start.
|
|
1826
|
+
chat = self.client.aio.chats.create(
|
|
1827
|
+
model=model,
|
|
1828
|
+
history=history,
|
|
1829
|
+
config=GenerateContentConfig(
|
|
1830
|
+
system_instruction=system_prompt,
|
|
1831
|
+
tools=gemini_tools,
|
|
1832
|
+
temperature=temperature or self.temperature,
|
|
1833
|
+
max_output_tokens=current_max_tokens
|
|
1834
|
+
)
|
|
1835
|
+
)
|
|
1836
|
+
retry_count += 1
|
|
1837
|
+
await self._wait_with_backoff(retry_count, retry_config)
|
|
1838
|
+
keep_looping = True
|
|
1839
|
+
continue
|
|
1840
|
+
|
|
1841
|
+
if retry_count < retry_config.max_retries:
|
|
1842
|
+
error_msg = f"\n\n⚠️ **Streaming error (attempt {retry_count + 1}): {str(e)}. Retrying...**\n\n"
|
|
1843
|
+
yield error_msg
|
|
1844
|
+
retry_count += 1
|
|
1845
|
+
await self._wait_with_backoff(retry_count, retry_config)
|
|
1846
|
+
keep_looping = True
|
|
1847
|
+
continue
|
|
1848
|
+
else:
|
|
1849
|
+
yield f"\n\n❌ **Streaming failed: {str(e)}**\n"
|
|
1850
|
+
break
|
|
1851
|
+
|
|
1852
|
+
# Update conversation memory
|
|
1853
|
+
final_text = "".join(all_assistant_text)
|
|
1854
|
+
if final_text:
|
|
1855
|
+
final_assistant_message = {
|
|
1856
|
+
"role": "assistant", "content": [
|
|
1857
|
+
{"type": "text", "text": final_text}
|
|
1858
|
+
]
|
|
1859
|
+
}
|
|
1860
|
+
# Extract assistant response text for conversation memory
|
|
1861
|
+
await self._update_conversation_memory(
|
|
1862
|
+
user_id,
|
|
1863
|
+
session_id,
|
|
1864
|
+
conversation_history,
|
|
1865
|
+
messages + [final_assistant_message],
|
|
1866
|
+
system_prompt,
|
|
1867
|
+
turn_id,
|
|
1868
|
+
prompt,
|
|
1869
|
+
final_text,
|
|
1870
|
+
[] # We don't easily track tool usage in stream return yet, or we could track in loop
|
|
1871
|
+
)
|
|
1872
|
+
|
|
1873
|
+
    async def batch_ask(self, requests) -> List[AIMessage]:
        """Process multiple requests in batch."""
        # Google GenAI doesn't have a native batch API, so we process sequentially
        results = []
        for request in requests:
            result = await self.ask(**request)
            results.append(result)
        return results

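    # --- Illustrative usage sketch (not part of this module) -----------------
    # batch_ask() above expects an iterable of keyword-argument dicts, each of
    # which is forwarded to ask(); for example (commented out, values assumed):
    #
    # results = await client.batch_ask([
    #     {"prompt": "Summarise Q1 revenue", "use_tools": True},
    #     {"prompt": "Summarise Q2 revenue", "use_tools": True},
    # ])
    # for msg in results:
    #     print(msg.provider)
    # --------------------------------------------------------------------------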
|
|
1882
|
+
async def ask_to_image(
|
|
1883
|
+
self,
|
|
1884
|
+
prompt: str,
|
|
1885
|
+
image: Union[Path, bytes],
|
|
1886
|
+
reference_images: Optional[Union[List[Path], List[bytes]]] = None,
|
|
1887
|
+
model: Union[str, GoogleModel] = None,
|
|
1888
|
+
max_tokens: Optional[int] = None,
|
|
1889
|
+
temperature: Optional[float] = None,
|
|
1890
|
+
structured_output: Union[type, StructuredOutputConfig] = None,
|
|
1891
|
+
count_objects: bool = False,
|
|
1892
|
+
user_id: Optional[str] = None,
|
|
1893
|
+
session_id: Optional[str] = None,
|
|
1894
|
+
no_memory: bool = False,
|
|
1895
|
+
) -> AIMessage:
|
|
1896
|
+
"""
|
|
1897
|
+
Ask a question to Google's Generative AI using a stateful chat session.
|
|
1898
|
+
"""
|
|
1899
|
+
model = model.value if isinstance(model, GoogleModel) else model
|
|
1900
|
+
if not model:
|
|
1901
|
+
model = self.model or GoogleModel.GEMINI_2_5_FLASH.value
|
|
1902
|
+
turn_id = str(uuid.uuid4())
|
|
1903
|
+
original_prompt = prompt
|
|
1904
|
+
|
|
1905
|
+
if no_memory:
|
|
1906
|
+
# For no_memory mode, skip conversation memory
|
|
1907
|
+
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
|
1908
|
+
conversation_session = None
|
|
1909
|
+
else:
|
|
1910
|
+
messages, conversation_session, _ = await self._prepare_conversation_context(
|
|
1911
|
+
prompt, None, user_id, session_id, None
|
|
1912
|
+
)
|
|
1913
|
+
|
|
1914
|
+
# Prepare conversation history for Google GenAI format
|
|
1915
|
+
history = []
|
|
1916
|
+
if messages:
|
|
1917
|
+
for msg in messages[:-1]: # Exclude the current user message (last in list)
|
|
1918
|
+
role = msg['role'].lower()
|
|
1919
|
+
if role == 'user':
|
|
1920
|
+
parts = []
|
|
1921
|
+
for part_content in msg.get('content', []):
|
|
1922
|
+
if isinstance(part_content, dict) and part_content.get('type') == 'text':
|
|
1923
|
+
parts.append(Part(text=part_content.get('text', '')))
|
|
1924
|
+
if parts:
|
|
1925
|
+
history.append(UserContent(parts=parts))
|
|
1926
|
+
elif role in ['assistant', 'model']:
|
|
1927
|
+
parts = []
|
|
1928
|
+
for part_content in msg.get('content', []):
|
|
1929
|
+
if isinstance(part_content, dict) and part_content.get('type') == 'text':
|
|
1930
|
+
parts.append(Part(text=part_content.get('text', '')))
|
|
1931
|
+
if parts:
|
|
1932
|
+
history.append(ModelContent(parts=parts))
|
|
1933
|
+
|
|
1934
|
+
# --- Multi-Modal Content Preparation ---
|
|
1935
|
+
if isinstance(image, Path):
|
|
1936
|
+
if not image.exists():
|
|
1937
|
+
raise FileNotFoundError(
|
|
1938
|
+
f"Image file not found: {image}"
|
|
1939
|
+
)
|
|
1940
|
+
# Load the primary image
|
|
1941
|
+
primary_image = Image.open(image)
|
|
1942
|
+
elif isinstance(image, bytes):
|
|
1943
|
+
primary_image = Image.open(io.BytesIO(image))
|
|
1944
|
+
elif isinstance(image, Image.Image):
|
|
1945
|
+
primary_image = image
|
|
1946
|
+
else:
|
|
1947
|
+
raise ValueError(
|
|
1948
|
+
"Image must be a Path, bytes, or PIL.Image object."
|
|
1949
|
+
)
|
|
1950
|
+
|
|
1951
|
+
# The content for the API call is a list containing images and the final prompt
|
|
1952
|
+
contents = [primary_image]
|
|
1953
|
+
if reference_images:
|
|
1954
|
+
for ref_path in reference_images:
|
|
1955
|
+
self.logger.debug(
|
|
1956
|
+
f"Loading reference image from: {ref_path}"
|
|
1957
|
+
)
|
|
1958
|
+
if isinstance(ref_path, Path):
|
|
1959
|
+
if not ref_path.exists():
|
|
1960
|
+
raise FileNotFoundError(
|
|
1961
|
+
f"Reference image file not found: {ref_path}"
|
|
1962
|
+
)
|
|
1963
|
+
contents.append(Image.open(ref_path))
|
|
1964
|
+
elif isinstance(ref_path, bytes):
|
|
1965
|
+
contents.append(Image.open(io.BytesIO(ref_path)))
|
|
1966
|
+
elif isinstance(ref_path, Image.Image):
|
|
1967
|
+
# is already a PIL.Image Object
|
|
1968
|
+
contents.append(ref_path)
|
|
1969
|
+
else:
|
|
1970
|
+
raise ValueError(
|
|
1971
|
+
"Reference Image must be a Path, bytes, or PIL.Image object."
|
|
1972
|
+
)
|
|
1973
|
+
|
|
1974
|
+
contents.append(prompt) # The text prompt always comes last
|
|
1975
|
+
generation_config = {
|
|
1976
|
+
"max_output_tokens": max_tokens or self.max_tokens,
|
|
1977
|
+
"temperature": temperature or self.temperature,
|
|
1978
|
+
}
|
|
1979
|
+
output_config = self._get_structured_config(structured_output)
|
|
1980
|
+
structured_output_config = output_config
|
|
1981
|
+
# Vision models generally don't support tools, so we focus on structured output
|
|
1982
|
+
if structured_output_config:
|
|
1983
|
+
self.logger.debug("Structured output requested for vision task.")
|
|
1984
|
+
self._apply_structured_output_schema(generation_config, structured_output_config)
|
|
1985
|
+
elif count_objects:
|
|
1986
|
+
# Default to JSON for structured output if not specified
|
|
1987
|
+
structured_output_config = StructuredOutputConfig(output_type=ObjectDetectionResult)
|
|
1988
|
+
self._apply_structured_output_schema(generation_config, structured_output_config)
|
|
1989
|
+
|
|
1990
|
+
# Create the stateful chat session
|
|
1991
|
+
chat = self.client.aio.chats.create(model=model, history=history)
|
|
1992
|
+
final_config = GenerateContentConfig(**generation_config)
|
|
1993
|
+
|
|
1994
|
+
# Make the primary multi-modal call
|
|
1995
|
+
self.logger.debug(f"Sending {len(contents)} parts to the model.")
|
|
1996
|
+
response = await chat.send_message(
|
|
1997
|
+
message=contents,
|
|
1998
|
+
config=final_config
|
|
1999
|
+
)
|
|
2000
|
+
|
|
2001
|
+
# --- Response Handling ---
|
|
2002
|
+
final_output = None
|
|
2003
|
+
if structured_output_config:
|
|
2004
|
+
try:
|
|
2005
|
+
final_output = await self._parse_structured_output(
|
|
2006
|
+
response.text,
|
|
2007
|
+
structured_output_config
|
|
2008
|
+
)
|
|
2009
|
+
except Exception as e:
|
|
2010
|
+
self.logger.error(
|
|
2011
|
+
f"Failed to parse structured output from vision model: {e}"
|
|
2012
|
+
)
|
|
2013
|
+
final_output = response.text
|
|
2014
|
+
elif '```json' in response.text:
|
|
2015
|
+
# Attempt to extract JSON from markdown code block
|
|
2016
|
+
try:
|
|
2017
|
+
final_output = self._parse_json_from_text(response.text)
|
|
2018
|
+
except Exception as e:
|
|
2019
|
+
self.logger.error(
|
|
2020
|
+
f"Failed to parse JSON from markdown in vision model response: {e}"
|
|
2021
|
+
)
|
|
2022
|
+
final_output = response.text
|
|
2023
|
+
else:
|
|
2024
|
+
final_output = response.text
|
|
2025
|
+
|
|
2026
|
+
final_assistant_message = {
|
|
2027
|
+
"role": "model", "content": [
|
|
2028
|
+
{"type": "text", "text": final_output}
|
|
2029
|
+
]
|
|
2030
|
+
}
|
|
2031
|
+
if no_memory is False:
|
|
2032
|
+
await self._update_conversation_memory(
|
|
2033
|
+
user_id,
|
|
2034
|
+
session_id,
|
|
2035
|
+
conversation_session,
|
|
2036
|
+
messages + [
|
|
2037
|
+
{
|
|
2038
|
+
"role": "user",
|
|
2039
|
+
"content": [
|
|
2040
|
+
{"type": "text", "text": f"[Image Analysis]: {prompt}"}
|
|
2041
|
+
]
|
|
2042
|
+
},
|
|
2043
|
+
final_assistant_message
|
|
2044
|
+
],
|
|
2045
|
+
None,
|
|
2046
|
+
turn_id,
|
|
2047
|
+
original_prompt,
|
|
2048
|
+
response.text,
|
|
2049
|
+
[]
|
|
2050
|
+
)
|
|
2051
|
+
ai_message = AIMessageFactory.from_gemini(
|
|
2052
|
+
response=response,
|
|
2053
|
+
input_text=original_prompt,
|
|
2054
|
+
model=model,
|
|
2055
|
+
user_id=user_id,
|
|
2056
|
+
session_id=session_id,
|
|
2057
|
+
turn_id=turn_id,
|
|
2058
|
+
structured_output=final_output if final_output != response.text else None,
|
|
2059
|
+
tool_calls=[]
|
|
2060
|
+
)
|
|
2061
|
+
ai_message.provider = "google_genai"
|
|
2062
|
+
return ai_message
|
|
2063
|
+
|
|
2064
|
+
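    # --- Illustrative usage sketch (not part of this module) -----------------
    # generate_images() below consumes an ImageGenerationPrompt; the field names
    # used here (prompt, styles, aspect_ratio, model) appear in the method body,
    # but the concrete values are assumptions for illustration (commented out):
    #
    # prompt_data = ImageGenerationPrompt(
    #     prompt="A lighthouse at dusk",
    #     styles=["watercolor", "high detail"],
    #     aspect_ratio="16:9",
    #     model=GoogleModel.IMAGEN_3.value,
    # )
    # message = await client.generate_images(
    #     prompt_data,
    #     number_of_images=2,
    #     output_directory=Path("./generated"),
    # )
    # --------------------------------------------------------------------------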
    async def generate_images(
        self,
        prompt_data: ImageGenerationPrompt,
        model: Union[str, GoogleModel] = GoogleModel.IMAGEN_3,
        reference_image: Optional[Path] = None,
        output_directory: Optional[Path] = None,
        mime_format: str = "image/jpeg",
        number_of_images: int = 1,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
        add_watermark: bool = False
    ) -> AIMessage:
        """
        Generates images based on a text prompt using Imagen.
        """
        if prompt_data.model:
            model = GoogleModel.IMAGEN_3.value
        model = model.value if isinstance(model, GoogleModel) else model
        self.logger.info(
            f"Starting image generation with model: {model}"
        )
        if model == GoogleModel.GEMINI_2_0_IMAGE_GENERATION.value:
            image_provider = "google_genai"
            config = types.GenerateContentConfig(
                response_modalities=['TEXT', 'IMAGE']
            )
        else:
            image_provider = "google_imagen"

        full_prompt = prompt_data.prompt
        if prompt_data.styles:
            full_prompt += ", " + ", ".join(prompt_data.styles)

        if reference_image:
            self.logger.info(
                f"Using reference image: {reference_image}"
            )
            if not reference_image.exists():
                raise FileNotFoundError(
                    f"Reference image not found: {reference_image}"
                )
            # Load the reference image
            ref_image = Image.open(reference_image)
            full_prompt = [full_prompt, ref_image]

        config = types.GenerateImagesConfig(
            number_of_images=number_of_images,
            output_mime_type=mime_format,
            safety_filter_level="BLOCK_LOW_AND_ABOVE",
            person_generation="ALLOW_ADULT",  # Or ALLOW_ALL, etc.
            aspect_ratio=prompt_data.aspect_ratio,
        )

        try:
            start_time = time.time()
            # Use the asynchronous client for image generation
            image_response = await self.client.aio.models.generate_images(
                model=prompt_data.model,
                prompt=full_prompt,
                config=config
            )
            execution_time = time.time() - start_time

            pil_images = []
            saved_image_paths = []
            raw_response = {}  # Initialize an empty dict for the raw response

            if image_response.generated_images:
                self.logger.info(
                    f"Successfully generated {len(image_response.generated_images)} image(s)."
                )
                raw_response['generated_images'] = []
                for i, generated_image in enumerate(image_response.generated_images):
                    pil_image = generated_image.image
                    pil_images.append(pil_image)

                    raw_response['generated_images'].append({
                        'uri': getattr(generated_image, 'uri', None),
                        'seed': getattr(generated_image, 'seed', None)
                    })

                    if output_directory:
                        file_path = self._save_image(pil_image, output_directory)
                        saved_image_paths.append(file_path)

            usage = CompletionUsage(execution_time=execution_time)
            # The primary 'output' is the list of raw PIL.Image objects
            # The new 'images' attribute holds the file paths
            ai_message = AIMessageFactory.from_imagen(
                output=pil_images,
                images=saved_image_paths,
                input=full_prompt,
                model=model,
                user_id=user_id,
                session_id=session_id,
                provider=image_provider,
                usage=usage,
                raw_response=raw_response
            )
            return ai_message

        except Exception as e:
            self.logger.error(f"Image generation failed: {e}")
            raise

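    # Usage sketch for generate_images (illustrative only; assumes an already-initialized
    # instance of this client named `llm` and the ImageGenerationPrompt fields referenced
    # above: prompt, styles, aspect_ratio):
    #
    #     image_prompt = ImageGenerationPrompt(prompt="a watercolor lighthouse at dawn",
    #                                          aspect_ratio="16:9")
    #     msg = await llm.generate_images(image_prompt, number_of_images=2,
    #                                     output_directory=Path("/tmp/images"))
    #     # msg.output holds the raw PIL.Image objects; msg.images holds saved file paths.
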
    def _find_voice_for_speaker(self, speaker: FictionalSpeaker) -> str:
        """
        Find the best voice for a speaker based on their characteristics and gender.

        Args:
            speaker: The fictional speaker configuration

        Returns:
            Voice name string
        """
        if not self.voice_db:
            self.logger.warning(
                "Voice database not available, using default voice"
            )
            return "erinome"  # Default fallback

        try:
            # First, try to find voices by characteristic
            characteristic_voices = self.voice_db.get_voices_by_characteristic(
                speaker.characteristic
            )

            if characteristic_voices:
                # Filter by gender if possible
                gender_filtered = [
                    v for v in characteristic_voices if v.gender == speaker.gender
                ]
                if gender_filtered:
                    return gender_filtered[0].voice_name.lower()
                else:
                    # Use first voice with matching characteristic regardless of gender
                    return characteristic_voices[0].voice_name.lower()

            # Fallback: find by gender only
            gender_voices = self.voice_db.get_voices_by_gender(speaker.gender)
            if gender_voices:
                self.logger.info(
                    f"Found voice by gender '{speaker.gender}': {gender_voices[0].voice_name}"
                )
                return gender_voices[0].voice_name.lower()

            # Ultimate fallback
            self.logger.warning(
                f"No voice found for speaker {speaker.name}, using default"
            )
            return "erinome"

        except Exception as e:
            self.logger.error(
                f"Error finding voice for speaker {speaker.name}: {e}"
            )
            return "erinome"

    async def create_conversation_script(
        self,
        report_data: ConversationalScriptConfig,
        model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
        temperature: float = 0.7,
        use_structured_output: bool = False,
        max_lines: int = 20
    ) -> AIMessage:
        """
        Creates a conversation script using Google's Generative AI.
        Generates a fictional conversational script from a text report using a generative model.
        Generates a complete, TTS-ready prompt for a two-person conversation
        based on a source text report.

        This method is designed to create a script that can be used with Google's TTS system.

        Returns:
            A string formatted for Google's TTS `generate_content` method.
            Example:
                "Make Speaker1 sound tired and bored, and Speaker2 sound excited and happy:

                Speaker1: So... what's on the agenda today?
                Speaker2: You're never going to guess!"
        """
        model = model.value if isinstance(model, GoogleModel) else model
        self.logger.info(
            f"Starting Conversation Script with model: {model}"
        )
        turn_id = str(uuid.uuid4())

        report_text = report_data.report_text
        if not report_text:
            raise ValueError(
                "Report text is required for generating a conversation script."
            )
        # Calculate conversation length
        conversation_length = min(report_data.length // 50, max_lines)
        if conversation_length < 4:
            conversation_length = max_lines
        system_prompt = report_data.system_prompt or "Create a natural and engaging conversation script based on the provided report."
        context = report_data.context or "This conversation is based on a report about a specific topic. The characters will discuss the key findings and insights from the report."
        interviewer = None
        interviewee = None
        for speaker in report_data.speakers:
            if not speaker.name or not speaker.role or not speaker.characteristic:
                raise ValueError(
                    "Each speaker must have a name, role, and characteristic."
                )
            # role (interviewer or interviewee) and characteristic (e.g., friendly, professional)
            if speaker.role == "interviewer":
                interviewer = speaker
            elif speaker.role == "interviewee":
                interviewee = speaker

        if not interviewer or not interviewee:
            raise ValueError("Must have exactly one interviewer and one interviewee.")
        system_instruction = report_data.system_instruction or f"""
        You are a scriptwriter. Your task is {system_prompt} for a conversation between {interviewer.name} and {interviewee.name}. "

        **Source Report:**"
        ---
        {report_text}
        ---

        **context:**
        {context}


        **Characters:**
        1. **{interviewer.name}**: The {interviewer.role}. Their personality is **{interviewer.characteristic}**.
        2. **{interviewee.name}**: The {interviewee.role}. Their personality is **{interviewee.characteristic}**.

        **Conversation Length:** {conversation_length} lines.
        **Instructions:**
        - The conversation must be based on the key findings, data, and conclusions of the source report.
        - The interviewer should ask insightful questions to guide the conversation.
        - The interviewee should provide answers and explanations derived from the report.
        - The dialogue should reflect the specified personalities of the characters.
        - The conversation should be engaging, natural, and suitable for a TTS system.
        - The script should be formatted for TTS, with clear speaker lines.

        **Gender–Neutral Output (Strict)**
        - Do NOT infer anyone's gender or use third-person gendered pronouns or titles: he, him, his, she, her, hers, Mr., Mrs., Ms., sir, ma’am, etc.
        - If a third person must be referenced, use singular they/them/their or repeat the name/role (e.g., “the manager”, “Alex”).
        - Do not include gendered stage directions (“in a feminine/masculine voice”).
        - First/second person is fine inside dialogue (“I”, “you”), but NEVER use gendered third-person forms.

        Before finalizing, scan and fix any gendered terms. If any banned term appears, rewrite that line to comply.

        - **IMPORTANT**: Generate ONLY the dialogue script. Do not include headers, titles, or any text other than the speaker lines. The format must be exactly:
        {interviewer.name}: [dialogue]
        {interviewee.name}: [dialogue]
        """
        generation_config = {
            "max_output_tokens": self.max_tokens,
            "temperature": temperature or self.temperature,
        }

        # Build contents for the stateless API call
        contents = [{
            "role": "user",
            "parts": [{"text": report_text}]
        }]

        final_config = GenerateContentConfig(
            system_instruction=system_instruction,
            safety_settings=[
                types.SafetySetting(
                    category=types.HarmCategory.HARM_CATEGORY_HARASSMENT,
                    threshold=types.HarmBlockThreshold.BLOCK_NONE,
                ),
                types.SafetySetting(
                    category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                    threshold=types.HarmBlockThreshold.BLOCK_NONE,
                ),
            ],
            tools=None,  # No tools needed for conversation script
            **generation_config
        )

        # Make a stateless call to the model
        if not self.client:
            self.client = await self.get_client()
        # response = await self.client.aio.models.generate_content(
        #     model=model,
        #     contents=contents,
        #     config=final_config
        # )
        sync_generate_content = partial(
            self.client.models.generate_content,
            model=model,
            contents=contents,
            config=final_config
        )
        # Run the synchronous function in a separate thread
        response = await asyncio.to_thread(sync_generate_content)
        # Extract the generated script text
        script_text = response.text if hasattr(response, 'text') else str(response)
        structured_output = script_text
        if use_structured_output:
            self.logger.info("Creating structured output for TTS system...")
            try:
                # Map speakers to voices
                speaker_configs = []
                for speaker in report_data.speakers:
                    voice = self._find_voice_for_speaker(speaker)
                    speaker_configs.append(
                        SpeakerConfig(name=speaker.name, voice=voice)
                    )
                    self.logger.notice(
                        f"Assigned voice '{voice}' to speaker '{speaker.name}'"
                    )
                structured_output = SpeechGenerationPrompt(
                    prompt=script_text,
                    speakers=speaker_configs
                )
            except Exception as e:
                self.logger.error(
                    f"Failed to create structured output: {e}"
                )
                # Continue without structured output rather than failing

        # Create the AIMessage response using the factory
        ai_message = AIMessageFactory.from_gemini(
            response=response,
            input_text=report_text,
            model=model,
            user_id=user_id,
            session_id=session_id,
            turn_id=turn_id,
            structured_output=structured_output,
            tool_calls=[]
        )
        ai_message.provider = "google_genai"

        return ai_message

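    # Usage sketch for create_conversation_script (illustrative only; assumes a client
    # instance `llm` and ConversationalScriptConfig/FictionalSpeaker objects carrying the
    # fields referenced above: report_text, length, and speakers with name/role/characteristic):
    #
    #     cfg = ConversationalScriptConfig(
    #         report_text=report, length=1500,
    #         speakers=[
    #             FictionalSpeaker(name="Alex", role="interviewer", characteristic="curious"),
    #             FictionalSpeaker(name="Sam", role="interviewee", characteristic="professional"),
    #         ])
    #     msg = await llm.create_conversation_script(cfg, use_structured_output=True)
    #     # With use_structured_output=True, msg.structured_output is a SpeechGenerationPrompt
    #     # with voices already mapped, ready to feed into generate_speech().
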
    async def generate_speech(
        self,
        prompt_data: SpeechGenerationPrompt,
        model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH_TTS,
        output_directory: Optional[Path] = None,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        mime_format: str = "audio/wav",  # or "audio/mpeg", "audio/webm"
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
        max_retries: int = 3,
        retry_delay: float = 1.0
    ) -> AIMessage:
        """
        Generates speech from text using either a single voice or multiple voices.
        """
        start_time = time.time()
        if prompt_data.model:
            model = prompt_data.model
        model = model.value if isinstance(model, GoogleModel) else model
        self.logger.info(
            f"Starting Speech generation with model: {model}"
        )

        # Validation of voices and fallback logic before creating the SpeechConfig:
        valid_voices = {v.value for v in TTSVoice}
        processed_speakers = []
        for speaker in prompt_data.speakers:
            final_voice = speaker.voice
            if speaker.voice not in valid_voices:
                self.logger.warning(
                    f"Invalid voice '{speaker.voice}' for speaker '{speaker.name}'. "
                    "Using default voice instead."
                )
                gender = speaker.gender.lower() if speaker.gender else 'female'
                final_voice = 'zephyr' if gender == 'female' else 'charon'
            processed_speakers.append(
                SpeakerConfig(name=speaker.name, voice=final_voice, gender=speaker.gender)
            )

        speech_config = None
        if len(processed_speakers) == 1:
            # Single-speaker configuration
            speaker = processed_speakers[0]
            gender = speaker.gender or 'female'
            default_voice = 'Charon' if gender == 'female' else 'Puck'
            voice = speaker.voice or default_voice
            self.logger.info(f"Using single voice: {voice}")
            speech_config = types.SpeechConfig(
                voice_config=types.VoiceConfig(
                    prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice)
                ),
                language_code=prompt_data.language or "en-US"  # Default to US English
            )
        else:
            # Multi-speaker configuration
            self.logger.info(
                f"Using multiple voices: {[s.voice for s in processed_speakers]}"
            )
            speaker_voice_configs = [
                types.SpeakerVoiceConfig(
                    speaker=s.name,
                    voice_config=types.VoiceConfig(
                        prebuilt_voice_config=types.PrebuiltVoiceConfig(
                            voice_name=s.voice
                        )
                    )
                ) for s in processed_speakers
            ]
            speech_config = types.SpeechConfig(
                multi_speaker_voice_config=types.MultiSpeakerVoiceConfig(
                    speaker_voice_configs=speaker_voice_configs
                ),
                language_code=prompt_data.language or "en-US"  # Default to US English
            )

        config = types.GenerateContentConfig(
            response_modalities=["AUDIO"],
            speech_config=speech_config,
            system_instruction=system_prompt,
            temperature=temperature
        )
        # Retry logic for network errors
        if not self.client:
            self.client = await self.get_client()
        # chat = self.client.aio.chats.create(model=model, history=None, config=config)
        for attempt in range(max_retries + 1):

            try:
                if attempt > 0:
                    delay = retry_delay * (2 ** (attempt - 1))  # Exponential backoff
                    self.logger.info(
                        f"Retrying speech (attempt {attempt + 1}/{max_retries + 1}) after {delay}s delay..."
                    )
                    await asyncio.sleep(delay)
                # response = await self.client.aio.models.generate_content(
                #     model=model,
                #     contents=prompt_data.prompt,
                #     config=config,
                # )
                sync_generate_content = partial(
                    self.client.models.generate_content,
                    model=model,
                    contents=prompt_data.prompt,
                    config=config
                )
                # Run the synchronous function in a separate thread
                response = await asyncio.to_thread(sync_generate_content)
                # Robust audio data extraction with proper validation
                audio_data = self._extract_audio_data(response)
                if audio_data is None:
                    # Log the response structure for debugging
                    self.logger.error("Failed to extract audio data from response")
                    self.logger.debug(f"Response type: {type(response)}")
                    if hasattr(response, 'candidates'):
                        self.logger.debug(f"Candidates count: {len(response.candidates) if response.candidates else 0}")
                        if response.candidates and len(response.candidates) > 0:
                            candidate = response.candidates[0]
                            self.logger.debug(f"Candidate type: {type(candidate)}")
                            self.logger.debug(f"Candidate has content: {hasattr(candidate, 'content')}")
                            if hasattr(candidate, 'content'):
                                content = candidate.content
                                self.logger.debug(f"Content is None: {content is None}")
                                if content:
                                    self.logger.debug(f"Content has parts: {hasattr(content, 'parts')}")
                                    if hasattr(content, 'parts'):
                                        self.logger.debug(f"Parts count: {len(content.parts) if content.parts else 0}")

                    raise SpeechGenerationError(
                        "No audio data found in response. The speech generation may have failed or "
                        "the model may not support speech generation for this request."
                    )

                saved_file_paths = []

                if output_directory:
                    output_directory.mkdir(parents=True, exist_ok=True)
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    file_path = output_directory / f"generated_speech_{timestamp}.wav"

                    self._save_audio_file(audio_data, file_path, mime_format)
                    saved_file_paths.append(file_path)
                    self.logger.info(
                        f"Saved speech to {file_path}"
                    )

                execution_time = time.time() - start_time
                usage = CompletionUsage(
                    execution_time=execution_time,
                    # Speech API does not return token counts
                    input_tokens=len(prompt_data.prompt),  # Approximation
                )

                ai_message = AIMessageFactory.from_speech(
                    output=audio_data,  # The raw PCM audio data
                    files=saved_file_paths,
                    input=prompt_data.prompt,
                    model=model,
                    provider="google_genai",
                    usage=usage,
                    user_id=user_id,
                    session_id=session_id,
                    raw_response=None  # Response object isn't easily serializable
                )
                return ai_message

            except (
                aiohttp.ClientPayloadError,
                aiohttp.ClientConnectionError,
                aiohttp.ClientResponseError,
                aiohttp.ServerTimeoutError,
                ConnectionResetError,
                TimeoutError,
                asyncio.TimeoutError
            ) as network_error:
                error_msg = str(network_error)

                # Specific handling for different network errors
                if "TransferEncodingError" in error_msg:
                    self.logger.warning(
                        f"Transfer encoding error on attempt {attempt + 1}: {error_msg}")
                elif "Connection reset by peer" in error_msg:
                    self.logger.warning(
                        f"Connection reset on attempt {attempt + 1}: Server closed connection")
                elif "timeout" in error_msg.lower():
                    self.logger.warning(
                        f"Timeout error on attempt {attempt + 1}: {error_msg}")
                else:
                    self.logger.warning(
                        f"Network error on attempt {attempt + 1}: {error_msg}"
                    )

                if attempt < max_retries:
                    self.logger.debug(
                        f"Will retry in {retry_delay * (2 ** attempt)}s..."
                    )
                    continue
                else:
                    # Max retries exceeded
                    self.logger.error(
                        f"Speech generation failed after {max_retries + 1} attempts"
                    )
                    raise SpeechGenerationError(
                        f"Speech generation failed after {max_retries + 1} attempts. "
                        f"Last error: {error_msg}. This is typically a temporary network issue - please try again."
                    ) from network_error

            except Exception as e:
                # Non-network errors - don't retry
                error_msg = str(e)
                self.logger.error(
                    f"Speech generation failed with non-retryable error: {error_msg}"
                )

                # Provide helpful error messages based on error type
                if "quota" in error_msg.lower() or "rate limit" in error_msg.lower():
                    raise SpeechGenerationError(
                        f"API quota or rate limit exceeded: {error_msg}. Please try again later."
                    ) from e
                elif "permission" in error_msg.lower() or "unauthorized" in error_msg.lower():
                    raise SpeechGenerationError(
                        f"Authorization error: {error_msg}. Please check your API credentials."
                    ) from e
                elif "model" in error_msg.lower():
                    raise SpeechGenerationError(
                        f"Model error: {error_msg}. The model '{model}' may not support speech generation."
                    ) from e
                else:
                    raise SpeechGenerationError(
                        f"Speech generation failed: {error_msg}"
                    ) from e

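    # Usage sketch for generate_speech (illustrative only; assumes a client instance `llm`
    # and the SpeechGenerationPrompt/SpeakerConfig models referenced above; the voice names
    # shown come from the fallback voices used in this method):
    #
    #     tts_prompt = SpeechGenerationPrompt(
    #         prompt="Alex: Welcome back!\nSam: Glad to be here.",
    #         speakers=[SpeakerConfig(name="Alex", voice="zephyr"),
    #                   SpeakerConfig(name="Sam", voice="charon")])
    #     msg = await llm.generate_speech(tts_prompt, output_directory=Path("/tmp/audio"))
    #     # msg.output carries the raw audio data; msg.files lists the saved .wav paths.
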
    def _extract_audio_data(self, response):
        """
        Robustly extract audio data from Google GenAI response.
        Similar to the text extraction pattern used elsewhere in the codebase.
        """
        try:
            # First attempt: Direct access to expected structure
            if (hasattr(response, 'candidates') and
                    response.candidates and
                    len(response.candidates) > 0 and
                    hasattr(response.candidates[0], 'content') and
                    response.candidates[0].content and
                    hasattr(response.candidates[0].content, 'parts') and
                    response.candidates[0].content.parts and
                    len(response.candidates[0].content.parts) > 0):

                for part in response.candidates[0].content.parts:
                    # Check for inline_data with audio data
                    if (hasattr(part, 'inline_data') and
                            part.inline_data and
                            hasattr(part.inline_data, 'data') and
                            part.inline_data.data):
                        self.logger.debug("Found audio data in inline_data.data")
                        return part.inline_data.data

                    # Alternative: Check for direct data attribute
                    if hasattr(part, 'data') and part.data:
                        self.logger.debug("Found audio data in part.data")
                        return part.data

                    # Alternative: Check for binary data
                    if hasattr(part, 'binary') and part.binary:
                        self.logger.debug("Found audio data in part.binary")
                        return part.binary

            self.logger.warning("No audio data found in expected response structure")
            return None

        except Exception as e:
            self.logger.error(f"Audio data extraction failed: {e}")
            return None

    async def generate_videos(
        self,
        prompt: VideoGenerationPrompt,
        reference_image: Optional[Path] = None,
        output_directory: Optional[Path] = None,
        mime_format: str = "video/mp4",
        model: Union[str, GoogleModel] = GoogleModel.VEO_3_0,
    ) -> AIMessage:
        """
        Generate a video using the specified model and prompt.
        """
        if prompt.model:
            model = prompt.model
        model = model.value if isinstance(model, GoogleModel) else model
        if model not in [GoogleModel.VEO_2_0.value, GoogleModel.VEO_3_0.value]:
            raise ValueError(
                "Generate Videos are only supported with VEO 2.0 or VEO 3.0 models."
            )
        self.logger.info(
            f"Starting Video generation with model: {model}"
        )
        if output_directory:
            output_directory.mkdir(parents=True, exist_ok=True)
        else:
            output_directory = BASE_DIR.joinpath('static', 'generated_videos')
        args = {
            "prompt": prompt.prompt,
            "model": model,
        }

        if reference_image:
            # if a reference image is used, only Veo2 is supported:
            self.logger.info(
                "Veo 3.0 does not support reference images, using VEO 2.0 instead."
            )
            model = GoogleModel.VEO_2_0.value
            self.logger.info(
                f"Using reference image: {reference_image}"
            )
            if not reference_image.exists():
                raise FileNotFoundError(
                    f"Reference image not found: {reference_image}"
                )
            # Load the reference image
            ref_image = Image.open(reference_image)
            args['image'] = types.Image(image_bytes=ref_image)

        start_time = time.time()
        operation = self.client.models.generate_videos(
            **args,
            config=types.GenerateVideosConfig(
                aspect_ratio=prompt.aspect_ratio or "16:9",  # Default to 16:9
                negative_prompt=prompt.negative_prompt,  # Optional negative prompt
                number_of_videos=prompt.number_of_videos,  # Number of videos to generate
            )
        )

        print("Video generation job started. Waiting for completion...", end="")
        spinner_chars = ['|', '/', '-', '\\']
        check_interval = 10  # Check status every 10 seconds
        spinner_index = 0

        # This loop checks the job status every 10 seconds
        while not operation.done:
            # This inner loop runs the spinner animation for the check_interval
            for _ in range(check_interval):
                # Write the spinner character to the console
                sys.stdout.write(
                    f"\rVideo generation job started. Waiting for completion... {spinner_chars[spinner_index]}"
                )
                sys.stdout.flush()
                spinner_index = (spinner_index + 1) % len(spinner_chars)
                time.sleep(1)  # Animate every second

            # After 10 seconds, get the updated operation status
            operation = self.client.operations.get(operation)

        print("\rVideo generation job completed. ", end="")

        for n, generated_video in enumerate(operation.result.generated_videos):
            # Download the generated videos
            # bytes of the original MP4
            mp4_bytes = self.client.files.download(file=generated_video.video)
            video_path = self._save_video_file(
                mp4_bytes,
                output_directory,
                video_number=n,
                mime_format=mime_format
            )
        execution_time = time.time() - start_time
        usage = CompletionUsage(
            execution_time=execution_time,
            # Video API does not return token counts
            input_tokens=len(prompt.prompt),  # Approximation
        )

        ai_message = AIMessageFactory.from_video(
            output=operation,  # The raw Video object
            files=[video_path],
            input=prompt.prompt,
            model=model,
            provider="google_genai",
            usage=usage,
            user_id=None,
            session_id=None,
            raw_response=None  # Response object isn't easily serializable
        )
        return ai_message

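    # Usage sketch for generate_videos (illustrative only; assumes a client instance `llm`
    # and a VideoGenerationPrompt carrying the fields used above: prompt, aspect_ratio,
    # negative_prompt, number_of_videos):
    #
    #     video_prompt = VideoGenerationPrompt(prompt="drone shot over a rocky coastline",
    #                                          number_of_videos=1)
    #     msg = await llm.generate_videos(video_prompt, output_directory=Path("/tmp/videos"))
    #     # msg.files lists the saved MP4 paths; msg.output is the completed operation object.
    #     # Note that passing reference_image forces the VEO 2.0 model, as handled above.
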
    async def _deep_research_ask(
        self,
        prompt: str,
        background: bool = False,
        file_search_store_names: Optional[List[str]] = None,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None
    ) -> AIMessage:
        """
        Perform deep research using Google's interactions.create() API.

        Note: This is a stub implementation. Full implementation requires the
        Google Gen AI interactions SDK which uses a different API than the
        standard models.generate_content().
        """
        self.logger.warning(
            "Google Deep Research is not yet fully implemented. "
            "This feature requires the interactions API which is currently in preview. "
            "Falling back to standard ask() behavior for now."
        )
        # TODO: Implement using client.interactions.create() when SDK supports it
        # For now, fall back to regular ask without deep_research flag
        return await self.ask(
            prompt=prompt,
            user_id=user_id,
            session_id=session_id,
            deep_research=False  # Prevent infinite recursion
        )

    async def question(
        self,
        prompt: str,
        model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        files: Optional[List[Union[str, Path]]] = None,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
        system_prompt: Optional[str] = None,
        structured_output: Union[type, StructuredOutputConfig] = None,
        use_internal_tools: bool = False,  # New parameter to control internal tools
    ) -> AIMessage:
        """
        Ask a question to Google's Generative AI in a stateless manner,
        without conversation history and with optional internal tools.

        Args:
            prompt (str): The input prompt for the model.
            model (Union[str, GoogleModel]): The model to use, defaults to GEMINI_2_5_FLASH.
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.
            files (Optional[List[Union[str, Path]]]): Optional files to include in the request.
            system_prompt (Optional[str]): Optional system prompt to guide the model.
            structured_output (Union[type, StructuredOutputConfig]): Optional structured output configuration.
            user_id (Optional[str]): Optional user identifier for tracking.
            session_id (Optional[str]): Optional session identifier for tracking.
            use_internal_tools (bool): If True, Gemini's built-in tools (e.g., Google Search)
                will be made available to the model. Defaults to False.
        """
        self.logger.info(
            f"Initiating RAG pipeline for prompt: '{prompt[:50]}...'"
        )

        model = model.value if isinstance(model, GoogleModel) else model
        turn_id = str(uuid.uuid4())
        original_prompt = prompt

        output_config = self._get_structured_config(structured_output)

        generation_config = {
            "max_output_tokens": max_tokens or self.max_tokens,
            "temperature": temperature or self.temperature,
        }

        if output_config:
            self._apply_structured_output_schema(generation_config, output_config)

        tools = None
        if use_internal_tools:
            tools = self._build_tools("builtin_tools")  # Only built-in tools
            self.logger.debug(
                "Enabled internal tool usage."
            )

        # Build contents for the stateless call
        contents = []
        if files:
            for file_path in files:
                # In a real scenario, you'd handle file uploads to Gemini properly
                # This is a placeholder for file content
                contents.append(
                    {
                        "part": {
                            "inline_data": {
                                "mime_type": "application/octet-stream",
                                "data": "BASE64_ENCODED_FILE_CONTENT"
                            }
                        }
                    }
                )

        # Add the user prompt as the first part
        contents.append({
            "role": "user",
            "parts": [{"text": prompt}]
        })

        all_tool_calls = []  # To capture any tool calls made by internal tools

        final_config = GenerateContentConfig(
            system_instruction=system_prompt,
            tools=tools,
            **generation_config
        )

        response = await self.client.aio.models.generate_content(
            model=model,
            contents=contents,
            config=final_config
        )

        # Handle potential internal tool calls if they are part of the direct generate_content response
        # Gemini can sometimes decide to use internal tools even without explicit function calling setup
        # if the tools are broadly enabled (e.g., through a general 'tool' parameter).
        # This part assumes Gemini's 'generate_content' directly returns tool calls if it uses them.
        if use_internal_tools and response.candidates and response.candidates[0].content.parts:
            function_calls = [
                part.function_call
                for part in response.candidates[0].content.parts
                if hasattr(part, 'function_call') and part.function_call
            ]
            if function_calls:
                tool_call_objects = []
                for fc in function_calls:
                    tc = ToolCall(
                        id=f"call_{uuid.uuid4().hex[:8]}",
                        name=fc.name,
                        arguments=dict(fc.args)
                    )
                    tool_call_objects.append(tc)

                start_time = time.time()
                tool_execution_tasks = [
                    self._execute_tool(fc.name, dict(fc.args)) for fc in function_calls
                ]
                tool_results = await asyncio.gather(
                    *tool_execution_tasks,
                    return_exceptions=True
                )
                execution_time = time.time() - start_time

                for tc, result in zip(tool_call_objects, tool_results):
                    tc.execution_time = execution_time / len(tool_call_objects)
                    if isinstance(result, Exception):
                        tc.error = str(result)
                    else:
                        tc.result = result

                all_tool_calls.extend(tool_call_objects)
                pass  # We're not doing a multi-turn here for stateless

        final_output = None
        if output_config:
            try:
                final_output = await self._parse_structured_output(
                    response.text,
                    output_config
                )
            except Exception:
                final_output = response.text

        ai_message = AIMessageFactory.from_gemini(
            response=response,
            input_text=original_prompt,
            model=model,
            user_id=user_id,
            session_id=session_id,
            turn_id=turn_id,
            structured_output=final_output if final_output != response.text else None,
            tool_calls=all_tool_calls
        )
        ai_message.provider = "google_genai"

        return ai_message

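    # Usage sketch for question() (illustrative only; assumes a client instance `llm`):
    #
    #     msg = await llm.question(
    #         "Summarize the latest release notes in two sentences.",
    #         system_prompt="Answer briefly and factually.",
    #         use_internal_tools=True)
    #     # msg is an AIMessage; with use_internal_tools=True any built-in tool calls the
    #     # model made are recorded on the message as ToolCall entries, as handled above.
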
    async def summarize_text(
        self,
        text: str,
        max_length: int = 500,
        min_length: int = 100,
        model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
        temperature: Optional[float] = None,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
    ) -> AIMessage:
        """
        Generates a summary for a given text in a stateless manner.

        Args:
            text (str): The text content to summarize.
            max_length (int): The maximum desired character length for the summary.
            min_length (int): The minimum desired character length for the summary.
            model (Union[str, GoogleModel]): The model to use.
            temperature (float): Sampling temperature for response generation.
            user_id (Optional[str]): Optional user identifier for tracking.
            session_id (Optional[str]): Optional session identifier for tracking.
        """
        self.logger.info(
            f"Generating summary for text: '{text[:50]}...'"
        )

        model = model.value if isinstance(model, GoogleModel) else model
        turn_id = str(uuid.uuid4())

        # Define the specific system prompt for summarization
        system_prompt = f"""
        Your job is to produce a final summary from the following text and identify the main theme.
        - The summary should be concise and to the point.
        - The summary should be no longer than {max_length} characters and no less than {min_length} characters.
        - The summary should be in a single paragraph.
        """

        generation_config = {
            "max_output_tokens": self.max_tokens,
            "temperature": temperature or self.temperature,
        }

        # Build contents for the stateless call. The 'prompt' is the text to be summarized.
        contents = [{
            "role": "user",
            "parts": [{"text": text}]
        }]

        final_config = GenerateContentConfig(
            system_instruction=system_prompt,
            tools=None,  # No tools needed for summarization
            **generation_config
        )

        # Make a stateless call to the model
        response = await self.client.aio.models.generate_content(
            model=model,
            contents=contents,
            config=final_config
        )

        # Create the AIMessage response using the factory
        ai_message = AIMessageFactory.from_gemini(
            response=response,
            input_text=text,
            model=model,
            user_id=user_id,
            session_id=session_id,
            turn_id=turn_id,
            structured_output=None,
            tool_calls=[]
        )
        ai_message.provider = "google_genai"

        return ai_message

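    # Usage sketch for summarize_text (illustrative only; assumes a client instance `llm`):
    #
    #     msg = await llm.summarize_text(long_report, max_length=400, min_length=150)
    #     # The length bounds are enforced only through the system prompt built above;
    #     # the summary text comes back inside the resulting AIMessage.
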
    async def translate_text(
        self,
        text: str,
        target_lang: str,
        source_lang: Optional[str] = None,
        model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
        temperature: Optional[float] = 0.2,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
    ) -> AIMessage:
        """
        Translates a given text from a source language to a target language.

        Args:
            text (str): The text content to translate.
            target_lang (str): The ISO code for the target language (e.g., 'es', 'fr').
            source_lang (Optional[str]): The ISO code for the source language.
                If None, the model will attempt to detect it.
            model (Union[str, GoogleModel]): The model to use. Defaults to GEMINI_2_5_FLASH,
                which is recommended for speed.
            temperature (float): Sampling temperature for response generation.
            user_id (Optional[str]): Optional user identifier for tracking.
            session_id (Optional[str]): Optional session identifier for tracking.
        """
        self.logger.info(
            f"Translating text to '{target_lang}': '{text[:50]}...'"
        )

        model = model.value if isinstance(model, GoogleModel) else model
        turn_id = str(uuid.uuid4())

        # Construct the system prompt for translation
        if source_lang:
            prompt_instruction = (
                f"Translate the following text from {source_lang} to {target_lang}. "
                "Only return the translated text, without any additional comments or explanations."
            )
        else:
            prompt_instruction = (
                f"First, detect the source language of the following text. Then, translate it to {target_lang}. "
                "Only return the translated text, without any additional comments or explanations."
            )

        generation_config = {
            "max_output_tokens": self.max_tokens,
            "temperature": temperature or self.temperature,
        }

        # Build contents for the stateless API call
        contents = [{
            "role": "user",
            "parts": [{"text": text}]
        }]

        final_config = GenerateContentConfig(
            system_instruction=prompt_instruction,
            tools=None,  # No tools needed for translation
            **generation_config
        )

        # Make a stateless call to the model
        response = await self.client.aio.models.generate_content(
            model=model,
            contents=contents,
            config=final_config
        )

        # Create the AIMessage response using the factory
        ai_message = AIMessageFactory.from_gemini(
            response=response,
            input_text=text,
            model=model,
            user_id=user_id,
            session_id=session_id,
            turn_id=turn_id,
            structured_output=None,
            tool_calls=[]
        )
        ai_message.provider = "google_genai"

        return ai_message

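    # Usage sketch for translate_text (illustrative only; assumes a client instance `llm`):
    #
    #     msg = await llm.translate_text("Buenos días", target_lang="en", source_lang="es")
    #     # With source_lang omitted, the system prompt built above asks the model to
    #     # detect the source language before translating.
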
    async def extract_key_points(
        self,
        text: str,
        num_points: int = 5,
        model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,  # Changed to GoogleModel
        temperature: Optional[float] = 0.3,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
    ) -> AIMessage:
        """
        Extract *num_points* bullet-point key ideas from *text* (stateless).
        """
        self.logger.info(
            f"Extracting {num_points} key points from text: '{text[:50]}...'"
        )

        model = model.value if isinstance(model, GoogleModel) else model
        turn_id = str(uuid.uuid4())

        system_instruction = (  # Changed to system_instruction for Google GenAI
            f"Extract the {num_points} most important key points from the following text.\n"
            "- Present each point as a clear, concise bullet point (•).\n"
            "- Focus on the main ideas and significant information.\n"
            "- Each point should be self-contained and meaningful.\n"
            "- Order points by importance (most important first)."
        )

        # Build contents for the stateless API call
        contents = [{
            "role": "user",
            "parts": [{"text": text}]
        }]

        generation_config = {
            "max_output_tokens": self.max_tokens,
            "temperature": temperature or self.temperature,
        }

        final_config = GenerateContentConfig(
            system_instruction=system_instruction,
            tools=None,  # No tools needed for this task
            **generation_config
        )

        # Make a stateless call to the model
        response = await self.client.aio.models.generate_content(
            model=model,
            contents=contents,
            config=final_config
        )

        # Create the AIMessage response using the factory
        ai_message = AIMessageFactory.from_gemini(
            response=response,
            input_text=text,
            model=model,
            user_id=user_id,
            session_id=session_id,
            turn_id=turn_id,
            structured_output=None,  # No structured output explicitly requested
            tool_calls=[]  # No tool calls for this method
        )
        ai_message.provider = "google_genai"  # Set provider

        return ai_message

    async def analyze_sentiment(
        self,
        text: str,
        model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
        temperature: Optional[float] = 0.1,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
        use_structured: bool = False,
    ) -> AIMessage:
        """
        Perform sentiment analysis on text and return a structured or unstructured response.

        Args:
            text (str): The text to analyze.
            model (Union[GoogleModel, str]): The model to use for the analysis.
            temperature (float): Sampling temperature for response generation.
            user_id (Optional[str]): Optional user identifier for tracking.
            session_id (Optional[str]): Optional session identifier for tracking.
            use_structured (bool): If True, forces a structured JSON output matching
                the SentimentAnalysis model. Defaults to False.
        """
        self.logger.info(f"Analyzing sentiment for text: '{text[:50]}...'")

        model_name = model.value if isinstance(model, GoogleModel) else model
        turn_id = str(uuid.uuid4())

        system_instruction = ""
        generation_config = {
            "max_output_tokens": self.max_tokens,
            "temperature": temperature or self.temperature,
        }
        structured_output_model = None

        if use_structured:
            # ✍️ Generate a prompt to force JSON output matching the Pydantic schema
            schema = SentimentAnalysis.model_json_schema()
            system_instruction = (
                "You are an expert in sentiment analysis. Analyze the following text and provide a structured JSON response. "
                "Your response MUST be a valid JSON object that conforms to the following JSON Schema. "
                "Do not include any other text, explanations, or markdown formatting like ```json ... ```.\n\n"
                f"JSON Schema:\n{self._json.dumps(schema, indent=2)}"
            )
            # Enable Gemini's JSON mode for reliable structured output
            generation_config["response_mime_type"] = "application/json"
            structured_output_model = SentimentAnalysis
        else:
            # The original prompt for a human-readable, unstructured response
            system_instruction = (
                "Analyze the sentiment of the following text and provide a structured response.\n"
                "Your response must include:\n"
                "1. Overall sentiment (Positive, Negative, Neutral, or Mixed)\n"
                "2. Confidence level (High, Medium, Low)\n"
                "3. Key emotional indicators found in the text\n"
                "4. Brief explanation of your analysis\n\n"
                "Format your answer clearly with numbered sections."
            )

        contents = [{"role": "user", "parts": [{"text": text}]}]

        final_config = GenerateContentConfig(
            system_instruction={"role": "system", "parts": [{"text": system_instruction}]},
            tools=None,
            **generation_config,
        )

        response = await self.client.aio.models.generate_content(
            model=model_name,
            contents=contents,
            config=final_config,
        )

        ai_message = AIMessageFactory.from_gemini(
            response=response,
            input_text=text,
            model=model_name,
            user_id=user_id,
            session_id=session_id,
            turn_id=turn_id,
            structured_output=structured_output_model,
            tool_calls=[],
        )
        ai_message.provider = "google_genai"

        return ai_message

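    # Usage sketch for analyze_sentiment (illustrative only; assumes a client instance `llm`
    # and the SentimentAnalysis Pydantic model referenced above):
    #
    #     msg = await llm.analyze_sentiment(
    #         "The onboarding was smooth, but support was slow to respond.",
    #         use_structured=True)
    #     # use_structured=True switches on Gemini's JSON mode and embeds the
    #     # SentimentAnalysis JSON Schema in the system instruction, as built above.
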
    async def analyze_product_review(
        self,
        review_text: str,
        product_id: str,
        product_name: str,
        model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
        temperature: Optional[float] = 0.1,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
        use_structured: bool = True,
    ) -> AIMessage:
        """
        Analyze a product review and extract structured or unstructured information.

        Args:
            review_text (str): The product review text to analyze.
            product_id (str): Unique identifier for the product.
            product_name (str): Name of the product being reviewed.
            model (Union[GoogleModel, str]): The model to use for the analysis.
            temperature (float): Sampling temperature for response generation.
            user_id (Optional[str]): Optional user identifier for tracking.
            session_id (Optional[str]): Optional session identifier for tracking.
            use_structured (bool): If True, forces a structured JSON output matching
                the ProductReview model. Defaults to True.
        """
        self.logger.info(
            f"Analyzing product review for product_id: '{product_id}'"
        )

        model = model.value if isinstance(model, GoogleModel) else model
        turn_id = str(uuid.uuid4())

        system_instruction = ""
        generation_config = {
            "max_output_tokens": self.max_tokens,
            "temperature": temperature or self.temperature,
        }
        structured_output_model = None

        if use_structured:
            # Generate a prompt to force JSON output matching the Pydantic schema
            schema = ProductReview.model_json_schema()
            system_instruction = (
                "You are a product review analysis expert. Analyze the provided product review "
                "and extract the required information. Your response MUST be a valid JSON object "
                "that conforms to the following JSON Schema. Do not include any other text, "
                "explanations, or markdown formatting like ```json ... ``` around the JSON object.\n\n"
                f"JSON Schema:\n{self._json.dumps(schema)}"
            )
            # Enable Gemini's JSON mode for reliable structured output
            generation_config["response_mime_type"] = "application/json"
            structured_output_model = ProductReview
        else:
            # Generate a prompt for a more general, text-based analysis
            system_instruction = (
                "You are a product review analysis expert. Analyze the sentiment and key aspects "
                "of the following product review.\n"
                "Your response must include:\n"
                "1. Overall sentiment (Positive, Negative, or Neutral)\n"
                "2. Estimated Rating (on a scale of 1-5)\n"
                "3. Key Positive Points mentioned\n"
                "4. Key Negative Points mentioned\n"
                "5. A brief summary of the review's main points."
            )

        # Build the user content part of the request
        user_prompt = (
            f"Product ID: {product_id}\n"
            f"Product Name: {product_name}\n"
            f"Review Text: \"{review_text}\""
        )
        contents = [{
            "role": "user",
            "parts": [{"text": user_prompt}]
        }]

        # Finalize the generation configuration
        final_config = GenerateContentConfig(
            system_instruction={"role": "system", "parts": [{"text": system_instruction}]},
            tools=None,
            **generation_config
        )

        # Make a stateless call to the model
        response = await self.client.aio.models.generate_content(
            model=model,
            contents=contents,
            config=final_config
        )

        # Create the AIMessage response using the factory
        ai_message = AIMessageFactory.from_gemini(
            response=response,
            input_text=user_prompt,  # Use the full prompt as input text
            model=model,
            user_id=user_id,
            session_id=session_id,
            turn_id=turn_id,
            structured_output=structured_output_model,
            tool_calls=[]
        )
        ai_message.provider = "google_genai"

        return ai_message

    async def image_generation(
        self,
        prompt_data: Union[str, ImageGenerationPrompt],
        model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH_IMAGE_PREVIEW,
        temperature: Optional[float] = None,
        prompt_instruction: Optional[str] = None,
        reference_images: List[Optional[Path]] = None,
        output_directory: Optional[Path] = None,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
        stateless: bool = True
    ) -> AIMessage:
        """
        Generates images based on a text prompt using Nano-Banana.
        """
        if isinstance(prompt_data, str):
            prompt_data = ImageGenerationPrompt(
                prompt=prompt_data,
                model=model,
            )
        if prompt_data.model:
            model = GoogleModel.GEMINI_2_5_FLASH_IMAGE_PREVIEW.value
        model = model.value if isinstance(model, GoogleModel) else model
        turn_id = str(uuid.uuid4())
        prompt_data.model = model

        self.logger.info(
            f"Starting image generation with model: {model}"
        )

        messages, conversation_session, _ = await self._prepare_conversation_context(
            prompt_data.prompt, None, user_id, session_id, None
        )

        full_prompt = prompt_data.prompt
        if prompt_data.styles:
            full_prompt += ", " + ", ".join(prompt_data.styles)

        # Prepare conversation history for Google GenAI format
        history = []
        if messages:
            for msg in messages[:-1]:  # Exclude the current user message (last in list)
                role = msg['role'].lower()
                if role == 'user':
                    parts = []
                    for part_content in msg.get('content', []):
                        if isinstance(part_content, dict) and part_content.get('type') == 'text':
                            parts.append(Part(text=part_content.get('text', '')))
                    if parts:
                        history.append(UserContent(parts=parts))
                elif role in ['assistant', 'model']:
                    parts = []
                    for part_content in msg.get('content', []):
                        if isinstance(part_content, dict) and part_content.get('type') == 'text':
                            parts.append(Part(text=part_content.get('text', '')))
                    if parts:
                        history.append(ModelContent(parts=parts))

        ref_images = []
        if reference_images:
            self.logger.info(
                f"Using reference image: {reference_images}"
            )
            for img_path in reference_images:
                if not img_path.exists():
                    raise FileNotFoundError(
                        f"Reference image not found: {img_path}"
                    )
                # Load the reference image
                ref_images.append(Image.open(img_path))

        config = types.GenerateContentConfig(
            response_modalities=['Text', 'Image'],
            temperature=temperature or self.temperature,
            system_instruction=prompt_instruction
        )

        try:
            start_time = time.time()
            content = [full_prompt, *ref_images] if ref_images else [full_prompt]
            # Use the asynchronous client for image generation
            if stateless:
                response = await self.client.aio.models.generate_content(
                    model=prompt_data.model,
                    contents=content,
                    config=config
                )
            else:
                # Create the stateful chat session
                chat = self.client.aio.chats.create(model=model, history=history, config=config)
                response = await chat.send_message(
                    message=content,
                )
            execution_time = time.time() - start_time

            pil_images = []
            saved_image_paths = []
            raw_response = {}  # Initialize an empty dict for the raw response

            raw_response['generated_images'] = []
            for part in response.candidates[0].content.parts:
                if part.text is not None:
                    raw_response['text'] = part.text
                elif part.inline_data is not None:
                    image = Image.open(io.BytesIO(part.inline_data.data))
                    pil_images.append(image)
                    if output_directory:
                        if isinstance(output_directory, str):
                            output_directory = Path(output_directory).resolve()
                        file_path = self._save_image(image, output_directory)
                        saved_image_paths.append(file_path)
                        raw_response['generated_images'].append({
                            'uri': file_path,
                            'seed': None
                        })

            usage = CompletionUsage(execution_time=execution_time)
            if not stateless:
                await self._update_conversation_memory(
                    user_id,
                    session_id,
                    conversation_session,
                    messages + [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": f"[Image Analysis]: {full_prompt}"}
                            ]
                        },
                    ],
                    None,
                    turn_id,
                    prompt_data.prompt,
                    response.text,
                    []
                )
            ai_message = AIMessageFactory.from_imagen(
                output=pil_images,
                images=saved_image_paths,
                input=full_prompt,
                model=model,
                user_id=user_id,
                session_id=session_id,
                provider='nano-banana',
                usage=usage,
                raw_response=raw_response
            )
            return ai_message

        except Exception as e:
            self.logger.error(f"Image generation failed: {e}")
            raise

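    # Illustrative usage sketch for image_generation (not part of the released
    # source). Assumptions: a configured GoogleGenAIClient instance named
    # `client`, and that the AIMessage built by AIMessageFactory.from_imagen
    # exposes the saved paths through an `images` attribute (hypothetical name).
    #
    #     result = await client.image_generation(
    #         "A watercolor hummingbird hovering over a hibiscus flower",
    #         output_directory=Path("./generated_images"),
    #         stateless=True,
    #     )
    #     for saved_path in result.images:
    #         print(saved_path)
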
    def _upload_video(self, video_path: Union[str, Path]) -> str:
        """
        Uploads a video file to Google GenAi Client.
        """
        if isinstance(video_path, str):
            video_path = Path(video_path).resolve()
        if not video_path.exists():
            raise FileNotFoundError(
                f"Video file not found: {video_path}"
            )
        video_file = self.client.files.upload(
            file=video_path
        )
        while video_file.state == "PROCESSING":
            time.sleep(10)
            video_file = self.client.files.get(name=video_file.name)

        if video_file.state == "FAILED":
            raise ValueError(video_file.state)

        self.logger.debug(
            f"Uploaded video file: {video_file.uri}"
        )

        return video_file

    async def video_understanding(
        self,
        prompt: str,
        model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
        temperature: Optional[float] = None,
        prompt_instruction: Optional[str] = None,
        video: Optional[Union[str, Path]] = None,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
        stateless: bool = True,
        offsets: Optional[tuple[str, str]] = None,
    ) -> AIMessage:
        """
        Uses a video (local file or YouTube URL) to analyze and extract information from videos.
        """
        model = model.value if isinstance(model, GoogleModel) else model
        turn_id = str(uuid.uuid4())

        self.logger.info(
            f"Starting video analysis with model: {model}"
        )

        if stateless:
            # For stateless mode, skip conversation memory
            messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
            conversation_history = None
        else:
            # Use the unified conversation context preparation from AbstractClient
            messages, conversation_history, prompt_instruction = await self._prepare_conversation_context(
                prompt, None, user_id, session_id, prompt_instruction, stateless=stateless
            )

        # Prepare conversation history for Google GenAI format
        history = []
        if messages:
            for msg in messages[:-1]:  # Exclude the current user message (last in list)
                role = msg['role'].lower()
                if role == 'user':
                    parts = []
                    for part_content in msg.get('content', []):
                        if isinstance(part_content, dict) and part_content.get('type') == 'text':
                            parts.append(Part(text=part_content.get('text', '')))
                    if parts:
                        history.append(UserContent(parts=parts))
                elif role in ['assistant', 'model']:
                    parts = []
                    for part_content in msg.get('content', []):
                        if isinstance(part_content, dict) and part_content.get('type') == 'text':
                            parts.append(Part(text=part_content.get('text', '')))
                    if parts:
                        history.append(ModelContent(parts=parts))

        config = types.GenerateContentConfig(
            response_modalities=['Text'],
            temperature=temperature or self.temperature,
            system_instruction=prompt_instruction
        )

        if isinstance(video, str) and video.startswith("http"):
            # YouTube video link:
            data = types.FileData(
                file_uri=video
            )
            video_metadata = None
            if offsets:
                video_metadata = types.VideoMetadata(
                    start_offset=offsets[0],
                    end_offset=offsets[1]
                )
            video_info = types.Part(
                file_data=data,
                video_metadata=video_metadata
            )
        else:
            video_info = self._upload_video(video)

        try:
            start_time = time.time()
            content = [
                types.Part(
                    text=prompt
                ),
                video_info
            ]
            # Use the asynchronous client for video analysis
            if stateless:
                response = await self.client.aio.models.generate_content(
                    model=model,
                    contents=content,
                    config=config
                )
            else:
                # Create the stateful chat session
                chat = self.client.aio.chats.create(model=model, history=history, config=config)
                response = await chat.send_message(
                    message=content,
                )
            execution_time = time.time() - start_time

            final_response = response.text

            usage = CompletionUsage(execution_time=execution_time)

            if not stateless:
                await self._update_conversation_memory(
                    user_id,
                    session_id,
                    conversation_history,
                    messages + [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": f"[Video Analysis]: {prompt}"}
                            ]
                        },
                    ],
                    None,
                    turn_id,
                    prompt,
                    final_response,
                    []
                )
            # Create AIMessage using factory
            ai_message = AIMessageFactory.from_gemini(
                response=response,
                input_text=prompt,
                model=model,
                user_id=user_id,
                session_id=session_id,
                turn_id=turn_id,
                structured_output=final_response,
                tool_calls=None,
                conversation_history=conversation_history,
                text_response=final_response
            )

            # Override provider to distinguish from Vertex AI
            ai_message.provider = "google_genai"

            return ai_message

        except Exception as e:
            self.logger.error(f"Video understanding failed: {e}")
            raise

    def _get_image_from_input(self, image: Union[str, Path, bytes, Image.Image]) -> Image.Image:
        """Helper to consistently load an image into a PIL object."""
        if isinstance(image, (str, Path)):
            return Image.open(image).convert("RGB")
        elif isinstance(image, bytes):
            return Image.open(io.BytesIO(image)).convert("RGB")
        else:
            return image.convert("RGB")

    def _crop_box(self, pil_img: Image.Image, box: DetectionBox) -> Image.Image:
        """Crops a detection box from a PIL image with a small padding."""
        # A small padding can provide more context to the model
        pad = 8
        x1 = max(0, box.x1 - pad)
        y1 = max(0, box.y1 - pad)
        x2 = min(pil_img.width, box.x2 + pad)
        y2 = min(pil_img.height, box.y2 + pad)
        return pil_img.crop((x1, y1, x2, y2))

    def _shelf_and_position(self, box: DetectionBox, regions: List[ShelfRegion]) -> Tuple[str, str]:
        """
        Determines the shelf and position for a given detection box using a robust
        centroid-based assignment logic.
        """
        if not regions:
            return "unknown", "center"

        # --- NEW LOGIC: Use the object's center point for assignment ---
        center_y = box.y1 + (box.y2 - box.y1) / 2
        best_region = None

        # 1. Primary Method: Find which shelf region CONTAINS the center point.
        for region in regions:
            if region.bbox.y1 <= center_y < region.bbox.y2:
                best_region = region
                break  # Found the correct shelf

        # 2. Fallback Method: If no shelf contains the center (edge case), find the closest one.
        if not best_region:
            min_distance = float('inf')
            for region in regions:
                shelf_center_y = region.bbox.y1 + (region.bbox.y2 - region.bbox.y1) / 2
                distance = abs(center_y - shelf_center_y)
                if distance < min_distance:
                    min_distance = distance
                    best_region = region

        shelf = best_region.level if best_region else "unknown"

        # --- Position logic remains the same, it's correct ---
        if best_region:
            box_center_x = (box.x1 + box.x2) / 2.0
            shelf_width = best_region.bbox.x2 - best_region.bbox.x1
            third_width = shelf_width / 3.0
            left_boundary = best_region.bbox.x1 + third_width
            right_boundary = best_region.bbox.x1 + 2 * third_width

            if box_center_x < left_boundary:
                position = "left"
            elif box_center_x > right_boundary:
                position = "right"
            else:
                position = "center"
        else:
            position = "center"

        return shelf, position

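    # Worked example for the centroid assignment above (hedged sketch; the
    # DetectionBox/ShelfRegion coordinates are hypothetical): a box with
    # y1=210 and y2=290 has center_y = 210 + (290 - 210) / 2 = 250. A shelf
    # region spanning y1=200..y2=400 contains 250, so that shelf is chosen
    # directly; only when no region contains the center does the fallback
    # pick the region whose own vertical center is closest to 250.
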
    async def image_identification(
        self,
        prompt: str,
        image: Union[Path, bytes, Image.Image],
        detections: List[DetectionBox],
        shelf_regions: List[ShelfRegion],
        reference_images: Optional[Dict[str, Union[Path, bytes, Image.Image]]] = None,
        model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_PRO,
        temperature: float = 0.0,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
    ) -> List[IdentifiedProduct]:
        """
        Identify products using detected boxes, reference images, and Gemini Vision.

        This method sends the full image, reference images, and individual crops of each
        detection to Gemini for precise identification, returning a structured list of
        IdentifiedProduct objects.

        Args:
            image: The main image of the retail display.
            detections: A list of `DetectionBox` objects from the initial detection step.
            shelf_regions: A list of `ShelfRegion` objects defining shelf boundaries.
            reference_images: Optional mapping of product labels to images showing ideal products.
            model: The Gemini model to use, defaulting to Gemini 2.5 Pro for its advanced vision capabilities.
            temperature: The sampling temperature for the model's response.

        Returns:
            A list of `IdentifiedProduct` objects with detailed identification info.
        """
        self.logger.info(f"Starting Gemini identification for {len(detections)} detections.")
        model_name = model.value if isinstance(model, GoogleModel) else model

        # --- 1. Prepare Images and Metadata ---
        main_image_pil = self._get_image_from_input(image)
        detection_details = []
        id_to_details = {}
        for i, det in enumerate(detections, start=1):
            shelf, pos = self._shelf_and_position(det, shelf_regions)
            detection_details.append({
                "id": i,
                "detection": det,
                "shelf": shelf,
                "position": pos,
                "crop": self._crop_box(main_image_pil, det),
            })
            id_to_details[i] = {"shelf": shelf, "position": pos, "detection": det}

        # --- 2. Construct the Multi-Modal Prompt for Gemini ---
        # The prompt is a list of parts: text instructions, reference images,
        # the main image, and finally the individual crops.
        contents = [Part(text=prompt)]  # Start with the user-provided prompt

        # --- Create a lookup map from ID to pre-calculated details ---
        id_to_details = {}
        for i, det in enumerate(detections, 1):
            shelf, pos = self._shelf_and_position(det, shelf_regions)
            id_to_details[i] = {"shelf": shelf, "position": pos, "detection": det}

        if reference_images:
            # Add a text part to introduce the references
            contents.append(Part(text="\n\n--- REFERENCE IMAGE GUIDE ---"))
            for label, ref_img_input in reference_images.items():
                # Add the label text, then the image
                contents.append(Part(text=f"Reference for '{label}':"))
                contents.append(self._get_image_from_input(ref_img_input))
            contents.append(Part(text="--- END REFERENCE GUIDE ---"))

        # Add the main image for overall context
        contents.append(main_image_pil)

        # Add each cropped detection image
        for item in detection_details:
            contents.append(item['crop'])

        for i, det in enumerate(detections, 1):
            contents.append(self._crop_box(main_image_pil, det))

        # Manually generate the JSON schema from the Pydantic model
        raw_schema = IdentificationResponse.model_json_schema()
        # Clean the schema to remove unsupported properties like 'additionalProperties'
        _schema = self.clean_google_schema(raw_schema)

        # --- 3. Configure the API Call for Structured Output ---
        generation_config = GenerateContentConfig(
            temperature=temperature,
            max_output_tokens=8192,  # Generous limit for JSON with many items
            response_mime_type="application/json",
            response_schema=_schema,
        )

        # --- 4. Call Gemini and Process the Response ---
        try:
            response = await self.client.aio.models.generate_content(
                model=model_name,
                contents=contents,
                config=generation_config,
            )
        except Exception as e:
            # If the error is 503 UNAVAILABLE ({'error': {'code': 503, 'message': 'The model is overloaded. Please try again later.', 'status': 'UNAVAILABLE'}}),
            # retry after a short delay, changing to gemini-2.5-flash instead of pro.
            await asyncio.sleep(1.5)
            response = await self.client.aio.models.generate_content(
                model='gemini-2.5-flash',
                contents=contents,
                config=generation_config,
            )

        try:
            response_text = self._safe_extract_text(response)
            if not response_text:
                raise ValueError(
                    "Received an empty response from the model."
                )

            print('RAW RESPONSE:', response_text)
            # The model output should conform to the Pydantic model directly
            parsed_data = IdentificationResponse.model_validate_json(response_text)
            identified_items = parsed_data.identified_products

            # --- 5. Link LLM results back to original detections ---
            final_products = []
            for item in identified_items:
                # Case 1: Item was pre-detected (has a positive ID)
                if item.detection_id is not None and item.detection_id > 0 and item.detection_id in id_to_details:
                    details = id_to_details[item.detection_id]
                    item.detection_box = details["detection"]

                    # Only use geometric fallback if LLM didn't provide shelf_location
                    if not item.shelf_location:
                        self.logger.warning(
                            f"LLM did not provide shelf_location for ID {item.detection_id}. Using geometric fallback."
                        )
                        item.shelf_location = details["shelf"]
                    if not item.position_on_shelf:
                        item.position_on_shelf = details["position"]
                    final_products.append(item)

                # Case 2: Item was newly found by the LLM
                elif item.detection_id is None:
                    if item.detection_box:
                        # TRUST the LLM's assignment, only use geometric fallback if missing
                        if not item.shelf_location:
                            self.logger.info("LLM didn't provide shelf_location, calculating geometrically")
                            shelf, pos = self._shelf_and_position(item.detection_box, shelf_regions)
                            item.shelf_location = shelf
                            item.position_on_shelf = pos
                        else:
                            # LLM provided shelf_location, trust it but calculate position if missing
                            self.logger.info(f"Using LLM-assigned shelf_location: {item.shelf_location}")
                            if not item.position_on_shelf:
                                _, pos = self._shelf_and_position(item.detection_box, shelf_regions)
                                item.position_on_shelf = pos

                        self.logger.info(
                            f"Adding new object found by LLM: {item.product_type} on shelf '{item.shelf_location}'"
                        )
                        final_products.append(item)

                # Case 3: Item was newly found by the LLM (has a negative ID from our validator)
                elif item.detection_id < 0:
                    if item.detection_box:
                        # TRUST the LLM's assignment, only use geometric fallback if missing
                        if not item.shelf_location:
                            self.logger.info("LLM didn't provide shelf_location, calculating geometrically")
                            shelf, pos = self._shelf_and_position(item.detection_box, shelf_regions)
                            item.shelf_location = shelf
                            item.position_on_shelf = pos
                        else:
                            # LLM provided shelf_location, trust it but calculate position if missing
                            self.logger.info(f"Using LLM-assigned shelf_location: {item.shelf_location}")
                            if not item.position_on_shelf:
                                _, pos = self._shelf_and_position(item.detection_box, shelf_regions)
                                item.position_on_shelf = pos

                        self.logger.info(f"Adding new object found by LLM: {item.product_type} on shelf '{item.shelf_location}'")
                        final_products.append(item)
                    else:
                        self.logger.warning(
                            f"LLM-found item with ID '{item.detection_id}' is missing a detection_box, skipping."
                        )

            self.logger.info(
                f"Successfully identified {len(final_products)} products."
            )
            return final_products

        except Exception as e:
            self.logger.error(
                f"Gemini image identification failed: {e}"
            )
            # Fallback to creating simple products from initial detections
            fallback_products = []
            for item in detection_details:
                shelf, pos = item["shelf"], item["position"]
                det = item["detection"]
                fallback_products.append(IdentifiedProduct(
                    detection_box=det,
                    detection_id=item['id'],
                    product_type=det.class_name,
                    product_model=None,
                    confidence=det.confidence * 0.5,  # Lower confidence for fallback
                    visual_features=["fallback_identification"],
                    reference_match="none",
                    shelf_location=shelf,
                    position_on_shelf=pos
                ))
            return fallback_products

    async def create_speech(
        self,
        content: str,
        voice_name: Optional[str] = 'charon',
        model: Union[str, GoogleModel] = GoogleModel.GEMINI_2_5_FLASH,
        output_directory: Optional[Path] = None,
        only_script: bool = False,
        script_file: str = "narration_script.txt",
        podcast_file: str = "generated_podcast.wav",
        mime_format: str = "audio/wav",
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        language: str = "en-US"
    ) -> AIMessage:
        """
        Generates a simple narrative script from text and then converts it to speech.
        This is a simpler, two-step process for text-to-speech generation.

        Args:
            content (str): The text content to generate speech from.
            voice_name (Optional[str]): The name of the voice to use. Defaults to 'charon'.
            model (Union[str, GoogleModel]): The model for the text-to-text step.
            output_directory (Optional[Path]): Directory to save the audio file.
            only_script (bool): If True, only generate and save the narration script.
            script_file (str): Filename for the saved narration script.
            podcast_file (str): Filename for the saved audio file.
            mime_format (str): The audio format, e.g., 'audio/wav'.
            user_id (Optional[str]): Optional user identifier.
            session_id (Optional[str]): Optional session identifier.
            max_retries (int): Maximum network retries.
            retry_delay (float): Delay for retries.
            language (str): Language code for the generated speech.

        Returns:
            An AIMessage object containing the generated audio, the text script, and metadata.
        """
        self.logger.info(
            "Starting a two-step text-to-speech process."
        )
        # Step 1: Generate a simple, narrated script from the provided text.
        system_prompt = f"""
        You are a professional scriptwriter. Given the input text, generate a clear, narrative-style script suitable for a voiceover.

        **Instructions:**
        - The conversation should be engaging, natural, and suitable for a TTS system.
        - The script should be formatted for TTS, with clear speaker lines.
        """
        script_prompt = f"""
        Read the following text in a clear, narrative style, suitable for a voiceover.
        Ensure the tone is neutral and professional. Do not add any conversational
        elements. Just read the text.

        Text:
        ---
        {content}
        ---
        """
        script_text = ''
        script_response = None
        try:
            script_response = await self.ask(
                prompt=script_prompt,
                model=model,
                system_prompt=system_prompt,
                temperature=0.0,
                stateless=True,
                use_tools=False,
            )
            script_text = script_response.output
        except Exception as e:
            self.logger.error(f"Script generation failed: {e}")
            raise SpeechGenerationError(
                f"Script generation failed: {str(e)}"
            ) from e

        if not script_text:
            raise SpeechGenerationError(
                "Script generation failed, could not proceed with speech generation."
            )

        self.logger.info("Generated script text successfully.")
        saved_file_paths = []
        if only_script:
            # If only the script is needed, save it and return it in an AIMessage
            output_directory.mkdir(parents=True, exist_ok=True)
            script_path = output_directory / script_file
            try:
                async with aiofiles.open(script_path, "w", encoding="utf-8") as f:
                    await f.write(script_text)
                self.logger.info(
                    f"Saved narration script to {script_path}"
                )
                saved_file_paths.append(script_path)
            except Exception as e:
                self.logger.error(f"Failed to save script file: {e}")
            ai_message = AIMessageFactory.from_gemini(
                response=script_response,
                text_response=script_text,
                input_text=content,
                model=model if isinstance(model, str) else model.value,
                user_id=user_id,
                session_id=session_id,
                files=saved_file_paths
            )
            return ai_message

        # Step 2: Generate speech from the generated script.
        speech_config_data = SpeechGenerationPrompt(
            prompt=script_text,
            speakers=[
                SpeakerConfig(
                    name="narrator",
                    voice=voice_name,
                )
            ],
            language=language
        )

        # Use the existing core logic to generate the audio
        model = GoogleModel.GEMINI_2_5_FLASH_TTS.value

        speaker = speech_config_data.speakers[0]
        final_voice = speaker.voice

        speech_config = types.SpeechConfig(
            voice_config=types.VoiceConfig(
                prebuilt_voice_config=types.PrebuiltVoiceConfig(
                    voice_name=final_voice
                )
            ),
            language_code=speech_config_data.language or "en-US"
        )

        config = types.GenerateContentConfig(
            response_modalities=["AUDIO"],
            speech_config=speech_config,
            temperature=0.7
        )

        for attempt in range(max_retries + 1):
            try:
                if attempt > 0:
                    delay = retry_delay * (2 ** (attempt - 1))
                    self.logger.info(
                        f"Retrying speech (attempt {attempt + 1}/{max_retries + 1}) after {delay}s delay..."
                    )
                    await asyncio.sleep(delay)
                start_time = time.time()
                response = await self.client.aio.models.generate_content(
                    model=model,
                    contents=speech_config_data.prompt,
                    config=config,
                )
                execution_time = time.time() - start_time
                audio_data = self._extract_audio_data(response)
                if audio_data is None:
                    raise SpeechGenerationError(
                        "No audio data found in response. The speech generation may have failed."
                    )

                saved_file_paths = []
                if output_directory:
                    output_directory.mkdir(parents=True, exist_ok=True)
                    podcast_path = output_directory / podcast_file
                    script_path = output_directory / script_file
                    self._save_audio_file(audio_data, podcast_path, mime_format)
                    saved_file_paths.append(podcast_path)
                    try:
                        async with aiofiles.open(script_path, "w", encoding="utf-8") as f:
                            await f.write(script_text)
                        self.logger.info(f"Saved narration script to {script_path}")
                        saved_file_paths.append(script_path)
                    except Exception as e:
                        self.logger.error(f"Failed to save script file: {e}")

                usage = CompletionUsage(
                    execution_time=execution_time,
                    input_tokens=len(script_text),
                )

                ai_message = AIMessageFactory.from_speech(
                    output=audio_data,
                    files=saved_file_paths,
                    input=script_text,
                    model=model,
                    provider="google_genai",
                    documents=[script_path],
                    usage=usage,
                    user_id=user_id,
                    session_id=session_id,
                    raw_response=None
                )
                return ai_message

            except (
                aiohttp.ClientPayloadError,
                aiohttp.ClientConnectionError,
                aiohttp.ClientResponseError,
                aiohttp.ServerTimeoutError,
                ConnectionResetError,
                TimeoutError,
                asyncio.TimeoutError
            ) as network_error:
                if attempt < max_retries:
                    self.logger.warning(
                        f"Network error on attempt {attempt + 1}: {str(network_error)}. Retrying..."
                    )
                    continue
                else:
                    self.logger.error(
                        f"Speech generation failed after {max_retries + 1} attempts"
                    )
                    raise SpeechGenerationError(
                        f"Speech generation failed after {max_retries + 1} attempts. "
                        f"Last error: {str(network_error)}."
                    ) from network_error

            except Exception as e:
                self.logger.error(
                    f"Speech generation failed with non-retryable error: {str(e)}"
                )
                raise SpeechGenerationError(
                    f"Speech generation failed: {str(e)}"
                ) from e

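    # Illustrative usage sketch for create_speech (not part of the released
    # source). Assumptions: a configured GoogleGenAIClient instance named
    # `client`; the WAV file and the narration script are written to the given
    # output directory as in the method body above, and the returned AIMessage
    # lists them in `files` (the factory call passes files=saved_file_paths).
    #
    #     speech = await client.create_speech(
    #         content="Parrot is an agent framework for Python.",
    #         voice_name="charon",
    #         output_directory=Path("./tts_output"),
    #     )
    #     print(speech.files)
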
    async def video_generation(
        self,
        prompt_data: Union[str, VideoGenerationPrompt],
        model: Union[str, GoogleModel] = GoogleModel.VEO_3_0,
        reference_image: Optional[Path] = None,
        generate_image_first: bool = False,
        image_prompt: Optional[str] = None,
        image_generation_model: str = "imagen-4.0-generate-001",
        aspect_ratio: Optional[str] = None,
        resolution: Optional[str] = None,
        negative_prompt: Optional[str] = None,
        output_directory: Optional[Path] = None,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
        stateless: bool = True,
        poll_interval: int = 10
    ) -> AIMessage:
        """
        Generates videos based on a text prompt using Veo models.

        Args:
            prompt_data: Text prompt or VideoGenerationPrompt object
            model: Video generation model (VEO_2_0 or VEO_3_0)
            reference_image: Optional path to reference image. If provided, this takes precedence.
            generate_image_first: If True and no reference_image, generates an image with Imagen first
            image_generation_model: Model to use for image generation (default: imagen-4.0-generate-001)
            aspect_ratio: Video aspect ratio (e.g., "16:9", "9:16"). Overrides prompt_data setting.
            resolution: Video resolution (e.g., "720p", "1080p"). Overrides prompt_data setting.
            negative_prompt: What to avoid in the video. Overrides prompt_data setting.
            output_directory: Directory to save generated videos
            user_id: User ID for conversation tracking
            session_id: Session ID for conversation tracking
            stateless: If True, no conversation memory is saved
            poll_interval: Seconds between polling checks (default: 10)

        Returns:
            AIMessage containing the generated video
        """
        # Parse prompt data
        if isinstance(prompt_data, str):
            prompt_data = VideoGenerationPrompt(
                prompt=prompt_data,
                model=model.value if isinstance(model, GoogleModel) else model,
            )

        # Validate and set model
        if prompt_data.model:
            model = prompt_data.model
        model = model.value if isinstance(model, GoogleModel) else model

        if model not in [GoogleModel.VEO_2_0.value, GoogleModel.VEO_3_0.value, GoogleModel.VEO_3_0_FAST.value]:
            raise ValueError(
                f"Video generation only supported with VEO 2.0 or VEO 3.0 models. Got: {model}"
            )

        # Setup output directory
        if output_directory:
            if isinstance(output_directory, str):
                output_directory = Path(output_directory).resolve()
            output_directory.mkdir(parents=True, exist_ok=True)
        else:
            output_directory = BASE_DIR.joinpath('static', 'generated_videos')
            output_directory.mkdir(parents=True, exist_ok=True)

        turn_id = str(uuid.uuid4())

        self.logger.info(
            f"Starting video generation with model: {model}"
        )

        # Prepare conversation context if not stateless
        if not stateless:
            messages, conversation_session, _ = await self._prepare_conversation_context(
                prompt_data.prompt, None, user_id, session_id, None
            )
        else:
            messages = None
            conversation_session = None

        # Override prompt settings with explicit parameters
        final_aspect_ratio = aspect_ratio or prompt_data.aspect_ratio or "16:9"
        final_resolution = resolution or getattr(prompt_data, 'resolution', None) or "720p"
        final_negative_prompt = negative_prompt or prompt_data.negative_prompt or ""

        # Step 1: Handle image input (reference or generate)
        generated_image = None
        image_for_video = None

        if reference_image:
            self.logger.info(
                f"Using reference image: {reference_image}"
            )
            if not reference_image.exists():
                raise FileNotFoundError(f"Reference image not found: {reference_image}")

            # VEO 3.0 doesn't support reference images, fall back to VEO 2.0
            # if model == GoogleModel.VEO_3_0.value:
            #     self.logger.warning(
            #         "VEO 3.0 does not support reference images. Switching to VEO 2.0."
            #     )
            #     model = GoogleModel.VEO_3_0_FAST

            # Load reference image
            ref_image_pil = Image.open(reference_image)
            # Convert PIL Image to bytes for Google GenAI API
            img_byte_arr = io.BytesIO()
            ref_image_pil.save(img_byte_arr, format=ref_image_pil.format or 'JPEG')
            img_byte_arr.seek(0)
            image_bytes = img_byte_arr.getvalue()

            image_for_video = types.Image(
                image_bytes=image_bytes,
                mime_type=f"image/{(ref_image_pil.format or 'jpeg').lower()}"
            )

        elif generate_image_first:
            self.logger.info(
                f"Generating image first with {image_generation_model} before video generation"
            )

            try:
                # Generate image using Imagen
                image_config = types.GenerateImagesConfig(
                    number_of_images=1,
                    output_mime_type="image/jpeg",
                    aspect_ratio=final_aspect_ratio
                )

                gen_prompt = image_prompt or prompt_data.prompt

                image_response = await self.client.aio.models.generate_images(
                    model=image_generation_model,
                    prompt=gen_prompt,
                    config=image_config
                )

                if image_response.generated_images:
                    generated_image = image_response.generated_images[0]
                    self.logger.info(
                        "Successfully generated reference image for video"
                    )

                    # Convert generated image to format needed for video generation
                    pil_image = generated_image.image
                    # We can use it directly because it is a google.genai.types.Image
                    image_for_video = pil_image
                    # Also, save the generated image to the output directory:
                    gen_image_path = output_directory / f"generated_image_{turn_id}.jpg"
                    pil_image.save(gen_image_path)
                    self.logger.info(
                        f"Saved generated reference image to: {gen_image_path}"
                    )

                    # VEO 3.0 doesn't support reference images
                    if model == GoogleModel.VEO_3_0.value:
                        self.logger.warning(
                            "VEO 3.0 does not support reference images. Switching to VEO 3.0 FAST"
                        )
                        model = GoogleModel.VEO_3_0_FAST
                else:
                    raise Exception("Image generation returned no images")

            except Exception as e:
                self.logger.error(f"Image generation failed: {e}")
                raise Exception(f"Failed to generate reference image: {e}")

        # Step 2: Generate video
        self.logger.info(f"Generating video with prompt: '{prompt_data.prompt[:100]}...'")

        try:
            start_time = time.time()

            # Prepare video generation arguments
            video_args = {
                "model": model,
                "prompt": prompt_data.prompt,
            }

            if image_for_video:
                video_args["image"] = image_for_video

            # Create config with all parameters
            video_config = types.GenerateVideosConfig(
                aspect_ratio=final_aspect_ratio,
                number_of_videos=prompt_data.number_of_videos or 1,
            )

            # Add resolution if supported (check model capabilities)
            if final_resolution:
                video_config.resolution = final_resolution

            # Add negative prompt if provided
            if final_negative_prompt:
                video_config.negative_prompt = final_negative_prompt

            video_args["config"] = video_config

            # Start async video generation operation
            self.logger.info("Starting async video generation operation...")
            operation = await self.client.aio.models.generate_videos(**video_args)

            # Step 3: Poll operation status asynchronously
            self.logger.info(
                f"Polling video generation status every {poll_interval} seconds..."
            )
            spinner_chars = ['|', '/', '-', '\\']
            spinner_index = 0
            poll_count = 0

            # This loop checks the job status every poll_interval seconds
            while not operation.done:
                poll_count += 1
                # This inner loop runs the spinner animation for the poll_interval
                for _ in range(poll_interval):
                    # Write the spinner character to the console
                    sys.stdout.write(
                        f"\rVideo generation job started. Waiting for completion... {spinner_chars[spinner_index]}"
                    )
                    sys.stdout.flush()
                    spinner_index = (spinner_index + 1) % len(spinner_chars)
                    await asyncio.sleep(1)  # Animate every second (async version)

                # After poll_interval seconds, get the updated operation status
                operation = await self.client.aio.operations.get(operation)

            print("\rVideo generation job completed. ", end="")
            sys.stdout.flush()

            execution_time = time.time() - start_time
            self.logger.info(
                f"Video generation completed in {execution_time:.2f}s after {poll_count} polls"
            )

            # Step 4: Download and save videos using bytes download
            generated_videos = operation.response.generated_videos

            if not generated_videos:
                raise Exception("Video generation completed but no videos were returned")

            saved_video_paths = []
            raw_response = {'generated_videos': []}

            for n, generated_video in enumerate(generated_videos):
                # Download the video bytes (MP4)
                # NOTE: Use sync client for file download as aio may not support it
                mp4_bytes = self.client.files.download(file=generated_video.video)

                # Save video to file using helper method
                video_path = self._save_video_file(
                    mp4_bytes,
                    output_directory,
                    video_number=n,
                    mime_format='video/mp4'
                )
                saved_video_paths.append(str(video_path))

                self.logger.info(f"Saved video to: {video_path}")

                # Collect metadata
                raw_response['generated_videos'].append({
                    'path': str(video_path),
                    'duration': getattr(generated_video, 'duration', None),
                    'uri': getattr(generated_video, 'uri', None),
                })

            # Step 5: Update conversation memory if not stateless
            usage = CompletionUsage(
                execution_time=execution_time,
                # Video API does not return token counts, use approximation
                input_tokens=len(prompt_data.prompt),
            )

            if not stateless and conversation_session:
                await self._update_conversation_memory(
                    user_id,
                    session_id,
                    conversation_session,
                    messages + [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": f"[Video Generation]: {prompt_data.prompt}"}
                            ]
                        },
                    ],
                    None,
                    turn_id,
                    prompt_data.prompt,
                    f"Generated {len(saved_video_paths)} video(s)",
                    []
                )

            # Step 6: Create and return AIMessage using the factory
            ai_message = AIMessageFactory.from_video(
                output=operation,  # The raw operation response object
                files=saved_video_paths,  # List of saved video file paths
                input=prompt_data.prompt,
                model=model,
                provider="google_genai",
                usage=usage,
                user_id=user_id,
                session_id=session_id,
                raw_response=None  # Response object isn't easily serializable
            )

            # Add metadata about the generation
            ai_message.metadata = {
                'aspect_ratio': final_aspect_ratio,
                'resolution': final_resolution,
                'negative_prompt': final_negative_prompt,
                'reference_image_used': reference_image is not None or generate_image_first,
                'image_generation_used': generate_image_first,
                'poll_count': poll_count,
                'execution_time': execution_time
            }

            self.logger.info(
                f"Video generation successful: {len(saved_video_paths)} video(s) created"
            )

            return ai_message

        except Exception as e:
            self.logger.error(f"Video generation failed: {e}", exc_info=True)
            raise

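    # Illustrative usage sketch for video_generation (not part of the released
    # source). Assumptions: a configured GoogleGenAIClient instance named
    # `client` with Veo access on the associated Google project; the returned
    # AIMessage lists the saved MP4 paths in `files` (the factory call passes
    # files=saved_video_paths).
    #
    #     video = await client.video_generation(
    #         "A slow dolly shot over a foggy pine forest at sunrise",
    #         aspect_ratio="16:9",
    #         resolution="720p",
    #         output_directory=Path("./generated_videos"),
    #     )
    #     print(video.files)
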
    def _save_video_file(
        self,
        video_bytes: bytes,
        output_directory: Path,
        video_number: int = 0,
        mime_format: str = "video/mp4"
    ) -> Path:
        """
        Helper method to save video bytes to disk.

        Args:
            video_bytes: Raw video bytes from the API
            output_directory: Directory to save the video
            video_number: Index number for the video filename
            mime_format: MIME type of the video (default: video/mp4)

        Returns:
            Path to saved video file
        """
        # Generate filename based on timestamp and video number
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"video_{timestamp}_{video_number}.mp4"

        video_path = output_directory / filename

        # Write bytes to file
        with open(video_path, 'wb') as f:
            f.write(video_bytes)

        self.logger.info(f"Saved {len(video_bytes)} bytes to {video_path}")

        return video_path


GoogleClient = GoogleGenAIClient  # Alias for easier imports