ai-parrot 0.17.2__cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentui/.prettierrc +15 -0
- agentui/QUICKSTART.md +272 -0
- agentui/README.md +59 -0
- agentui/env.example +16 -0
- agentui/jsconfig.json +14 -0
- agentui/package-lock.json +4242 -0
- agentui/package.json +34 -0
- agentui/scripts/postinstall/apply-patches.mjs +260 -0
- agentui/src/app.css +61 -0
- agentui/src/app.d.ts +13 -0
- agentui/src/app.html +12 -0
- agentui/src/components/LoadingSpinner.svelte +64 -0
- agentui/src/components/ThemeSwitcher.svelte +159 -0
- agentui/src/components/index.js +4 -0
- agentui/src/lib/api/bots.ts +60 -0
- agentui/src/lib/api/chat.ts +22 -0
- agentui/src/lib/api/http.ts +25 -0
- agentui/src/lib/components/BotCard.svelte +33 -0
- agentui/src/lib/components/ChatBubble.svelte +63 -0
- agentui/src/lib/components/Toast.svelte +21 -0
- agentui/src/lib/config.ts +20 -0
- agentui/src/lib/stores/auth.svelte.ts +73 -0
- agentui/src/lib/stores/theme.svelte.js +64 -0
- agentui/src/lib/stores/toast.svelte.ts +31 -0
- agentui/src/lib/utils/conversation.ts +39 -0
- agentui/src/routes/+layout.svelte +20 -0
- agentui/src/routes/+page.svelte +232 -0
- agentui/src/routes/login/+page.svelte +200 -0
- agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
- agentui/src/routes/talk/[agentId]/+page.ts +7 -0
- agentui/static/README.md +1 -0
- agentui/svelte.config.js +11 -0
- agentui/tailwind.config.ts +53 -0
- agentui/tsconfig.json +3 -0
- agentui/vite.config.ts +10 -0
- ai_parrot-0.17.2.dist-info/METADATA +472 -0
- ai_parrot-0.17.2.dist-info/RECORD +535 -0
- ai_parrot-0.17.2.dist-info/WHEEL +6 -0
- ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
- ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
- ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
- crew-builder/.prettierrc +15 -0
- crew-builder/QUICKSTART.md +259 -0
- crew-builder/README.md +113 -0
- crew-builder/env.example +17 -0
- crew-builder/jsconfig.json +14 -0
- crew-builder/package-lock.json +4182 -0
- crew-builder/package.json +37 -0
- crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
- crew-builder/src/app.css +62 -0
- crew-builder/src/app.d.ts +13 -0
- crew-builder/src/app.html +12 -0
- crew-builder/src/components/LoadingSpinner.svelte +64 -0
- crew-builder/src/components/ThemeSwitcher.svelte +149 -0
- crew-builder/src/components/index.js +9 -0
- crew-builder/src/lib/api/bots.ts +60 -0
- crew-builder/src/lib/api/chat.ts +80 -0
- crew-builder/src/lib/api/client.ts +56 -0
- crew-builder/src/lib/api/crew/crew.ts +136 -0
- crew-builder/src/lib/api/index.ts +5 -0
- crew-builder/src/lib/api/o365/auth.ts +65 -0
- crew-builder/src/lib/auth/auth.ts +54 -0
- crew-builder/src/lib/components/AgentNode.svelte +43 -0
- crew-builder/src/lib/components/BotCard.svelte +33 -0
- crew-builder/src/lib/components/ChatBubble.svelte +67 -0
- crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
- crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
- crew-builder/src/lib/components/JsonViewer.svelte +24 -0
- crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
- crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
- crew-builder/src/lib/components/Toast.svelte +67 -0
- crew-builder/src/lib/components/Toolbar.svelte +157 -0
- crew-builder/src/lib/components/index.ts +10 -0
- crew-builder/src/lib/config.ts +8 -0
- crew-builder/src/lib/stores/auth.svelte.ts +228 -0
- crew-builder/src/lib/stores/crewStore.ts +369 -0
- crew-builder/src/lib/stores/theme.svelte.js +145 -0
- crew-builder/src/lib/stores/toast.svelte.ts +69 -0
- crew-builder/src/lib/utils/conversation.ts +39 -0
- crew-builder/src/lib/utils/markdown.ts +122 -0
- crew-builder/src/lib/utils/talkHistory.ts +47 -0
- crew-builder/src/routes/+layout.svelte +20 -0
- crew-builder/src/routes/+page.svelte +539 -0
- crew-builder/src/routes/agents/+page.svelte +247 -0
- crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
- crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
- crew-builder/src/routes/builder/+page.svelte +204 -0
- crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
- crew-builder/src/routes/crew/ask/+page.ts +1 -0
- crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
- crew-builder/src/routes/login/+page.svelte +197 -0
- crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
- crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
- crew-builder/static/README.md +1 -0
- crew-builder/svelte.config.js +11 -0
- crew-builder/tailwind.config.ts +53 -0
- crew-builder/tsconfig.json +3 -0
- crew-builder/vite.config.ts +10 -0
- mcp_servers/calculator_server.py +309 -0
- parrot/__init__.py +27 -0
- parrot/__pycache__/__init__.cpython-310.pyc +0 -0
- parrot/__pycache__/version.cpython-310.pyc +0 -0
- parrot/_version.py +34 -0
- parrot/a2a/__init__.py +48 -0
- parrot/a2a/client.py +658 -0
- parrot/a2a/discovery.py +89 -0
- parrot/a2a/mixin.py +257 -0
- parrot/a2a/models.py +376 -0
- parrot/a2a/server.py +770 -0
- parrot/agents/__init__.py +29 -0
- parrot/bots/__init__.py +12 -0
- parrot/bots/a2a_agent.py +19 -0
- parrot/bots/abstract.py +3139 -0
- parrot/bots/agent.py +1129 -0
- parrot/bots/basic.py +9 -0
- parrot/bots/chatbot.py +669 -0
- parrot/bots/data.py +1618 -0
- parrot/bots/database/__init__.py +5 -0
- parrot/bots/database/abstract.py +3071 -0
- parrot/bots/database/cache.py +286 -0
- parrot/bots/database/models.py +468 -0
- parrot/bots/database/prompts.py +154 -0
- parrot/bots/database/retries.py +98 -0
- parrot/bots/database/router.py +269 -0
- parrot/bots/database/sql.py +41 -0
- parrot/bots/db/__init__.py +6 -0
- parrot/bots/db/abstract.py +556 -0
- parrot/bots/db/bigquery.py +602 -0
- parrot/bots/db/cache.py +85 -0
- parrot/bots/db/documentdb.py +668 -0
- parrot/bots/db/elastic.py +1014 -0
- parrot/bots/db/influx.py +898 -0
- parrot/bots/db/mock.py +96 -0
- parrot/bots/db/multi.py +783 -0
- parrot/bots/db/prompts.py +185 -0
- parrot/bots/db/sql.py +1255 -0
- parrot/bots/db/tools.py +212 -0
- parrot/bots/document.py +680 -0
- parrot/bots/hrbot.py +15 -0
- parrot/bots/kb.py +170 -0
- parrot/bots/mcp.py +36 -0
- parrot/bots/orchestration/README.md +463 -0
- parrot/bots/orchestration/__init__.py +1 -0
- parrot/bots/orchestration/agent.py +155 -0
- parrot/bots/orchestration/crew.py +3330 -0
- parrot/bots/orchestration/fsm.py +1179 -0
- parrot/bots/orchestration/hr.py +434 -0
- parrot/bots/orchestration/storage/__init__.py +4 -0
- parrot/bots/orchestration/storage/memory.py +100 -0
- parrot/bots/orchestration/storage/mixin.py +119 -0
- parrot/bots/orchestration/verify.py +202 -0
- parrot/bots/product.py +204 -0
- parrot/bots/prompts/__init__.py +96 -0
- parrot/bots/prompts/agents.py +155 -0
- parrot/bots/prompts/data.py +216 -0
- parrot/bots/prompts/output_generation.py +8 -0
- parrot/bots/scraper/__init__.py +3 -0
- parrot/bots/scraper/models.py +122 -0
- parrot/bots/scraper/scraper.py +1173 -0
- parrot/bots/scraper/templates.py +115 -0
- parrot/bots/stores/__init__.py +5 -0
- parrot/bots/stores/local.py +172 -0
- parrot/bots/webdev.py +81 -0
- parrot/cli.py +17 -0
- parrot/clients/__init__.py +16 -0
- parrot/clients/base.py +1491 -0
- parrot/clients/claude.py +1191 -0
- parrot/clients/factory.py +129 -0
- parrot/clients/google.py +4567 -0
- parrot/clients/gpt.py +1975 -0
- parrot/clients/grok.py +432 -0
- parrot/clients/groq.py +986 -0
- parrot/clients/hf.py +582 -0
- parrot/clients/models.py +18 -0
- parrot/conf.py +395 -0
- parrot/embeddings/__init__.py +9 -0
- parrot/embeddings/base.py +157 -0
- parrot/embeddings/google.py +98 -0
- parrot/embeddings/huggingface.py +74 -0
- parrot/embeddings/openai.py +84 -0
- parrot/embeddings/processor.py +88 -0
- parrot/exceptions.c +13868 -0
- parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/exceptions.pxd +22 -0
- parrot/exceptions.pxi +15 -0
- parrot/exceptions.pyx +44 -0
- parrot/generators/__init__.py +29 -0
- parrot/generators/base.py +200 -0
- parrot/generators/html.py +293 -0
- parrot/generators/react.py +205 -0
- parrot/generators/streamlit.py +203 -0
- parrot/generators/template.py +105 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/agent.py +861 -0
- parrot/handlers/agents/__init__.py +1 -0
- parrot/handlers/agents/abstract.py +900 -0
- parrot/handlers/bots.py +338 -0
- parrot/handlers/chat.py +915 -0
- parrot/handlers/creation.sql +192 -0
- parrot/handlers/crew/ARCHITECTURE.md +362 -0
- parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
- parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
- parrot/handlers/crew/__init__.py +0 -0
- parrot/handlers/crew/handler.py +801 -0
- parrot/handlers/crew/models.py +229 -0
- parrot/handlers/crew/redis_persistence.py +523 -0
- parrot/handlers/jobs/__init__.py +10 -0
- parrot/handlers/jobs/job.py +384 -0
- parrot/handlers/jobs/mixin.py +627 -0
- parrot/handlers/jobs/models.py +115 -0
- parrot/handlers/jobs/worker.py +31 -0
- parrot/handlers/models.py +596 -0
- parrot/handlers/o365_auth.py +105 -0
- parrot/handlers/stream.py +337 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/aws.py +143 -0
- parrot/interfaces/credentials.py +113 -0
- parrot/interfaces/database.py +27 -0
- parrot/interfaces/google.py +1123 -0
- parrot/interfaces/hierarchy.py +1227 -0
- parrot/interfaces/http.py +651 -0
- parrot/interfaces/images/__init__.py +0 -0
- parrot/interfaces/images/plugins/__init__.py +24 -0
- parrot/interfaces/images/plugins/abstract.py +58 -0
- parrot/interfaces/images/plugins/analisys.py +148 -0
- parrot/interfaces/images/plugins/classify.py +150 -0
- parrot/interfaces/images/plugins/classifybase.py +182 -0
- parrot/interfaces/images/plugins/detect.py +150 -0
- parrot/interfaces/images/plugins/exif.py +1103 -0
- parrot/interfaces/images/plugins/hash.py +52 -0
- parrot/interfaces/images/plugins/vision.py +104 -0
- parrot/interfaces/images/plugins/yolo.py +66 -0
- parrot/interfaces/images/plugins/zerodetect.py +197 -0
- parrot/interfaces/o365.py +978 -0
- parrot/interfaces/onedrive.py +822 -0
- parrot/interfaces/sharepoint.py +1435 -0
- parrot/interfaces/soap.py +257 -0
- parrot/loaders/__init__.py +8 -0
- parrot/loaders/abstract.py +1131 -0
- parrot/loaders/audio.py +199 -0
- parrot/loaders/basepdf.py +53 -0
- parrot/loaders/basevideo.py +1568 -0
- parrot/loaders/csv.py +409 -0
- parrot/loaders/docx.py +116 -0
- parrot/loaders/epubloader.py +316 -0
- parrot/loaders/excel.py +199 -0
- parrot/loaders/factory.py +55 -0
- parrot/loaders/files/__init__.py +0 -0
- parrot/loaders/files/abstract.py +39 -0
- parrot/loaders/files/html.py +26 -0
- parrot/loaders/files/text.py +63 -0
- parrot/loaders/html.py +152 -0
- parrot/loaders/markdown.py +442 -0
- parrot/loaders/pdf.py +373 -0
- parrot/loaders/pdfmark.py +320 -0
- parrot/loaders/pdftables.py +506 -0
- parrot/loaders/ppt.py +476 -0
- parrot/loaders/qa.py +63 -0
- parrot/loaders/splitters/__init__.py +10 -0
- parrot/loaders/splitters/base.py +138 -0
- parrot/loaders/splitters/md.py +228 -0
- parrot/loaders/splitters/token.py +143 -0
- parrot/loaders/txt.py +26 -0
- parrot/loaders/video.py +89 -0
- parrot/loaders/videolocal.py +218 -0
- parrot/loaders/videounderstanding.py +377 -0
- parrot/loaders/vimeo.py +167 -0
- parrot/loaders/web.py +599 -0
- parrot/loaders/youtube.py +504 -0
- parrot/manager/__init__.py +5 -0
- parrot/manager/manager.py +1030 -0
- parrot/mcp/__init__.py +28 -0
- parrot/mcp/adapter.py +105 -0
- parrot/mcp/cli.py +174 -0
- parrot/mcp/client.py +119 -0
- parrot/mcp/config.py +75 -0
- parrot/mcp/integration.py +842 -0
- parrot/mcp/oauth.py +933 -0
- parrot/mcp/server.py +225 -0
- parrot/mcp/transports/__init__.py +3 -0
- parrot/mcp/transports/base.py +279 -0
- parrot/mcp/transports/grpc_session.py +163 -0
- parrot/mcp/transports/http.py +312 -0
- parrot/mcp/transports/mcp.proto +108 -0
- parrot/mcp/transports/quic.py +1082 -0
- parrot/mcp/transports/sse.py +330 -0
- parrot/mcp/transports/stdio.py +309 -0
- parrot/mcp/transports/unix.py +395 -0
- parrot/mcp/transports/websocket.py +547 -0
- parrot/memory/__init__.py +16 -0
- parrot/memory/abstract.py +209 -0
- parrot/memory/agent.py +32 -0
- parrot/memory/cache.py +175 -0
- parrot/memory/core.py +555 -0
- parrot/memory/file.py +153 -0
- parrot/memory/mem.py +131 -0
- parrot/memory/redis.py +613 -0
- parrot/models/__init__.py +46 -0
- parrot/models/basic.py +118 -0
- parrot/models/compliance.py +208 -0
- parrot/models/crew.py +395 -0
- parrot/models/detections.py +654 -0
- parrot/models/generation.py +85 -0
- parrot/models/google.py +223 -0
- parrot/models/groq.py +23 -0
- parrot/models/openai.py +30 -0
- parrot/models/outputs.py +285 -0
- parrot/models/responses.py +938 -0
- parrot/notifications/__init__.py +743 -0
- parrot/openapi/__init__.py +3 -0
- parrot/openapi/components.yaml +641 -0
- parrot/openapi/config.py +322 -0
- parrot/outputs/__init__.py +32 -0
- parrot/outputs/formats/__init__.py +108 -0
- parrot/outputs/formats/altair.py +359 -0
- parrot/outputs/formats/application.py +122 -0
- parrot/outputs/formats/base.py +351 -0
- parrot/outputs/formats/bokeh.py +356 -0
- parrot/outputs/formats/card.py +424 -0
- parrot/outputs/formats/chart.py +436 -0
- parrot/outputs/formats/d3.py +255 -0
- parrot/outputs/formats/echarts.py +310 -0
- parrot/outputs/formats/generators/__init__.py +0 -0
- parrot/outputs/formats/generators/abstract.py +61 -0
- parrot/outputs/formats/generators/panel.py +145 -0
- parrot/outputs/formats/generators/streamlit.py +86 -0
- parrot/outputs/formats/generators/terminal.py +63 -0
- parrot/outputs/formats/holoviews.py +310 -0
- parrot/outputs/formats/html.py +147 -0
- parrot/outputs/formats/jinja2.py +46 -0
- parrot/outputs/formats/json.py +87 -0
- parrot/outputs/formats/map.py +933 -0
- parrot/outputs/formats/markdown.py +172 -0
- parrot/outputs/formats/matplotlib.py +237 -0
- parrot/outputs/formats/mixins/__init__.py +0 -0
- parrot/outputs/formats/mixins/emaps.py +855 -0
- parrot/outputs/formats/plotly.py +341 -0
- parrot/outputs/formats/seaborn.py +310 -0
- parrot/outputs/formats/table.py +397 -0
- parrot/outputs/formats/template_report.py +138 -0
- parrot/outputs/formats/yaml.py +125 -0
- parrot/outputs/formatter.py +152 -0
- parrot/outputs/templates/__init__.py +95 -0
- parrot/pipelines/__init__.py +0 -0
- parrot/pipelines/abstract.py +210 -0
- parrot/pipelines/detector.py +124 -0
- parrot/pipelines/models.py +90 -0
- parrot/pipelines/planogram.py +3002 -0
- parrot/pipelines/table.sql +97 -0
- parrot/plugins/__init__.py +106 -0
- parrot/plugins/importer.py +80 -0
- parrot/py.typed +0 -0
- parrot/registry/__init__.py +18 -0
- parrot/registry/registry.py +594 -0
- parrot/scheduler/__init__.py +1189 -0
- parrot/scheduler/models.py +60 -0
- parrot/security/__init__.py +16 -0
- parrot/security/prompt_injection.py +268 -0
- parrot/security/security_events.sql +25 -0
- parrot/services/__init__.py +1 -0
- parrot/services/mcp/__init__.py +8 -0
- parrot/services/mcp/config.py +13 -0
- parrot/services/mcp/server.py +295 -0
- parrot/services/o365_remote_auth.py +235 -0
- parrot/stores/__init__.py +7 -0
- parrot/stores/abstract.py +352 -0
- parrot/stores/arango.py +1090 -0
- parrot/stores/bigquery.py +1377 -0
- parrot/stores/cache.py +106 -0
- parrot/stores/empty.py +10 -0
- parrot/stores/faiss_store.py +1157 -0
- parrot/stores/kb/__init__.py +9 -0
- parrot/stores/kb/abstract.py +68 -0
- parrot/stores/kb/cache.py +165 -0
- parrot/stores/kb/doc.py +325 -0
- parrot/stores/kb/hierarchy.py +346 -0
- parrot/stores/kb/local.py +457 -0
- parrot/stores/kb/prompt.py +28 -0
- parrot/stores/kb/redis.py +659 -0
- parrot/stores/kb/store.py +115 -0
- parrot/stores/kb/user.py +374 -0
- parrot/stores/models.py +59 -0
- parrot/stores/pgvector.py +3 -0
- parrot/stores/postgres.py +2853 -0
- parrot/stores/utils/__init__.py +0 -0
- parrot/stores/utils/chunking.py +197 -0
- parrot/telemetry/__init__.py +3 -0
- parrot/telemetry/mixin.py +111 -0
- parrot/template/__init__.py +3 -0
- parrot/template/engine.py +259 -0
- parrot/tools/__init__.py +23 -0
- parrot/tools/abstract.py +644 -0
- parrot/tools/agent.py +363 -0
- parrot/tools/arangodbsearch.py +537 -0
- parrot/tools/arxiv_tool.py +188 -0
- parrot/tools/calculator/__init__.py +3 -0
- parrot/tools/calculator/operations/__init__.py +38 -0
- parrot/tools/calculator/operations/calculus.py +80 -0
- parrot/tools/calculator/operations/statistics.py +76 -0
- parrot/tools/calculator/tool.py +150 -0
- parrot/tools/cloudwatch.py +988 -0
- parrot/tools/codeinterpreter/__init__.py +127 -0
- parrot/tools/codeinterpreter/executor.py +371 -0
- parrot/tools/codeinterpreter/internals.py +473 -0
- parrot/tools/codeinterpreter/models.py +643 -0
- parrot/tools/codeinterpreter/prompts.py +224 -0
- parrot/tools/codeinterpreter/tool.py +664 -0
- parrot/tools/company_info/__init__.py +6 -0
- parrot/tools/company_info/tool.py +1138 -0
- parrot/tools/correlationanalysis.py +437 -0
- parrot/tools/database/abstract.py +286 -0
- parrot/tools/database/bq.py +115 -0
- parrot/tools/database/cache.py +284 -0
- parrot/tools/database/models.py +95 -0
- parrot/tools/database/pg.py +343 -0
- parrot/tools/databasequery.py +1159 -0
- parrot/tools/db.py +1800 -0
- parrot/tools/ddgo.py +370 -0
- parrot/tools/decorators.py +271 -0
- parrot/tools/dftohtml.py +282 -0
- parrot/tools/document.py +549 -0
- parrot/tools/ecs.py +819 -0
- parrot/tools/edareport.py +368 -0
- parrot/tools/elasticsearch.py +1049 -0
- parrot/tools/employees.py +462 -0
- parrot/tools/epson/__init__.py +96 -0
- parrot/tools/excel.py +683 -0
- parrot/tools/file/__init__.py +13 -0
- parrot/tools/file/abstract.py +76 -0
- parrot/tools/file/gcs.py +378 -0
- parrot/tools/file/local.py +284 -0
- parrot/tools/file/s3.py +511 -0
- parrot/tools/file/tmp.py +309 -0
- parrot/tools/file/tool.py +501 -0
- parrot/tools/file_reader.py +129 -0
- parrot/tools/flowtask/__init__.py +19 -0
- parrot/tools/flowtask/tool.py +761 -0
- parrot/tools/gittoolkit.py +508 -0
- parrot/tools/google/__init__.py +18 -0
- parrot/tools/google/base.py +169 -0
- parrot/tools/google/tools.py +1251 -0
- parrot/tools/googlelocation.py +5 -0
- parrot/tools/googleroutes.py +5 -0
- parrot/tools/googlesearch.py +5 -0
- parrot/tools/googlesitesearch.py +5 -0
- parrot/tools/googlevoice.py +2 -0
- parrot/tools/gvoice.py +695 -0
- parrot/tools/ibisworld/README.md +225 -0
- parrot/tools/ibisworld/__init__.py +11 -0
- parrot/tools/ibisworld/tool.py +366 -0
- parrot/tools/jiratoolkit.py +1718 -0
- parrot/tools/manager.py +1098 -0
- parrot/tools/math.py +152 -0
- parrot/tools/metadata.py +476 -0
- parrot/tools/msteams.py +1621 -0
- parrot/tools/msword.py +635 -0
- parrot/tools/multidb.py +580 -0
- parrot/tools/multistoresearch.py +369 -0
- parrot/tools/networkninja.py +167 -0
- parrot/tools/nextstop/__init__.py +4 -0
- parrot/tools/nextstop/base.py +286 -0
- parrot/tools/nextstop/employee.py +733 -0
- parrot/tools/nextstop/store.py +462 -0
- parrot/tools/notification.py +435 -0
- parrot/tools/o365/__init__.py +42 -0
- parrot/tools/o365/base.py +295 -0
- parrot/tools/o365/bundle.py +522 -0
- parrot/tools/o365/events.py +554 -0
- parrot/tools/o365/mail.py +992 -0
- parrot/tools/o365/onedrive.py +497 -0
- parrot/tools/o365/sharepoint.py +641 -0
- parrot/tools/openapi_toolkit.py +904 -0
- parrot/tools/openweather.py +527 -0
- parrot/tools/pdfprint.py +1001 -0
- parrot/tools/powerbi.py +518 -0
- parrot/tools/powerpoint.py +1113 -0
- parrot/tools/pricestool.py +146 -0
- parrot/tools/products/__init__.py +246 -0
- parrot/tools/prophet_tool.py +171 -0
- parrot/tools/pythonpandas.py +630 -0
- parrot/tools/pythonrepl.py +910 -0
- parrot/tools/qsource.py +436 -0
- parrot/tools/querytoolkit.py +395 -0
- parrot/tools/quickeda.py +827 -0
- parrot/tools/resttool.py +553 -0
- parrot/tools/retail/__init__.py +0 -0
- parrot/tools/retail/bby.py +528 -0
- parrot/tools/sandboxtool.py +703 -0
- parrot/tools/sassie/__init__.py +352 -0
- parrot/tools/scraping/__init__.py +7 -0
- parrot/tools/scraping/docs/select.md +466 -0
- parrot/tools/scraping/documentation.md +1278 -0
- parrot/tools/scraping/driver.py +436 -0
- parrot/tools/scraping/models.py +576 -0
- parrot/tools/scraping/options.py +85 -0
- parrot/tools/scraping/orchestrator.py +517 -0
- parrot/tools/scraping/readme.md +740 -0
- parrot/tools/scraping/tool.py +3115 -0
- parrot/tools/seasonaldetection.py +642 -0
- parrot/tools/shell_tool/__init__.py +5 -0
- parrot/tools/shell_tool/actions.py +408 -0
- parrot/tools/shell_tool/engine.py +155 -0
- parrot/tools/shell_tool/models.py +322 -0
- parrot/tools/shell_tool/tool.py +442 -0
- parrot/tools/site_search.py +214 -0
- parrot/tools/textfile.py +418 -0
- parrot/tools/think.py +378 -0
- parrot/tools/toolkit.py +298 -0
- parrot/tools/webapp_tool.py +187 -0
- parrot/tools/whatif.py +1279 -0
- parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
- parrot/tools/workday/__init__.py +6 -0
- parrot/tools/workday/models.py +1389 -0
- parrot/tools/workday/tool.py +1293 -0
- parrot/tools/yfinance_tool.py +306 -0
- parrot/tools/zipcode.py +217 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/helpers.py +73 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.c +12078 -0
- parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/parsers/toml.pyx +21 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpp +20936 -0
- parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/types.pyx +213 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- parrot/yaml-rs/Cargo.lock +350 -0
- parrot/yaml-rs/Cargo.toml +19 -0
- parrot/yaml-rs/pyproject.toml +19 -0
- parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
- parrot/yaml-rs/src/lib.rs +222 -0
- requirements/docker-compose.yml +24 -0
- requirements/requirements-dev.txt +21 -0
parrot/clients/base.py
ADDED
|
@@ -0,0 +1,1491 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import (
|
|
3
|
+
AsyncIterator,
|
|
4
|
+
Dict,
|
|
5
|
+
List,
|
|
6
|
+
Optional,
|
|
7
|
+
Union,
|
|
8
|
+
TypedDict,
|
|
9
|
+
Any,
|
|
10
|
+
Callable
|
|
11
|
+
)
|
|
12
|
+
from datetime import datetime
|
|
13
|
+
import json
|
|
14
|
+
import random
|
|
15
|
+
import re
|
|
16
|
+
import mimetypes
|
|
17
|
+
import asyncio
|
|
18
|
+
import base64
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from dataclasses import dataclass, is_dataclass
|
|
21
|
+
from abc import ABC, abstractmethod
|
|
22
|
+
import io
|
|
23
|
+
import yaml
|
|
24
|
+
from pydantic import (
|
|
25
|
+
BaseModel,
|
|
26
|
+
ValidationError,
|
|
27
|
+
TypeAdapter
|
|
28
|
+
)
|
|
29
|
+
from datamodel.exceptions import ParserError # pylint: disable=E0611 # noqa
|
|
30
|
+
from datamodel.parsers.json import json_decoder, JSONContent # pylint: disable=E0611 # noqa
|
|
31
|
+
import pandas as pd
|
|
32
|
+
import aiohttp
|
|
33
|
+
from navconfig import config
|
|
34
|
+
from navconfig.logging import logging
|
|
35
|
+
from ..memory import (
|
|
36
|
+
ConversationTurn,
|
|
37
|
+
ConversationHistory,
|
|
38
|
+
ConversationMemory,
|
|
39
|
+
InMemoryConversation,
|
|
40
|
+
FileConversationMemory,
|
|
41
|
+
RedisConversation
|
|
42
|
+
)
|
|
43
|
+
from ..tools.pythonrepl import PythonREPLTool
|
|
44
|
+
from ..models import (
|
|
45
|
+
StructuredOutputConfig,
|
|
46
|
+
OutputFormat
|
|
47
|
+
)
|
|
48
|
+
from ..tools.abstract import AbstractTool, ToolResult
|
|
49
|
+
from ..tools.manager import (
|
|
50
|
+
ToolManager,
|
|
51
|
+
ToolFormat,
|
|
52
|
+
ToolDefinition
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Named generation presets: each entry maps a use-case label to a sampling
# temperature and a max output-token budget. Consumed by
# AbstractClient.__init__ via its ``preset`` argument; unknown preset names
# fall back to the "default" entry there.
LLM_PRESETS = {
    "analytical": {"temperature": 0.1, "max_tokens": 4000},
    "creative": {"temperature": 0.7, "max_tokens": 6000},
    "balanced": {"temperature": 0.4, "max_tokens": 4000},
    "concise": {"temperature": 0.2, "max_tokens": 2000},
    "detailed": {"temperature": 0.3, "max_tokens": 8000},
    "comprehensive": {"temperature": 0.5, "max_tokens": 10000},
    "verbose": {"temperature": 0.6, "max_tokens": 12000},
    "summarization": {"temperature": 0.2, "max_tokens": 3000},
    "translation": {"temperature": 0.1, "max_tokens": 5000},
    "inspiration": {"temperature": 0.8, "max_tokens": 7000},
    # Used both as the fallback for unknown preset names and as the
    # lowest-cost configuration.
    "default": {"temperature": 0.1, "max_tokens": 1024}
}
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def register_python_tool(
    client,
    report_dir: Optional[Path] = None,
    plt_style: str = 'seaborn-v0_8-whitegrid',
    palette: str = 'Set2'
) -> PythonREPLTool:
    """Build a PythonREPLTool and register it on *client* as "python_repl".

    Works with any client exposing a ``register_tool(name, description,
    input_schema, function)`` method (e.g. the LLM API clients in this
    package).

    Args:
        client: Client instance providing ``register_tool``.
        report_dir: Directory where the REPL saves generated reports.
        plt_style: Matplotlib style applied inside the REPL.
        palette: Seaborn color palette applied inside the REPL.

    Returns:
        The PythonREPLTool instance that was registered.
    """
    repl = PythonREPLTool(
        report_dir=report_dir,
        plt_style=plt_style,
        palette=palette
    )

    description = (
        "A Python shell for executing Python commands. "
        "Input should be valid Python code. "
        "Pre-loaded libraries: pandas (pd), numpy (np), matplotlib.pyplot (plt), "
        "seaborn (sns), numexpr (ne). "
        "Available tools: quick_eda, generate_eda_report, list_available_dataframes "
        "from parrot_tools. "
        "Use execution_results dict for capturing intermediate results. "
        "Use report_directory Path for saving outputs. "
        "Use extended_json.dumps(obj)/extended_json.loads(bytes) for JSON operations."
    )

    client.register_tool(
        name="python_repl",
        description=description,
        input_schema=repl.get_tool_schema(),
        function=repl
    )

    return repl
|
|
112
|
+
|
|
113
|
+
class MessageResponse(TypedDict):
    """Response structure for LLM messages.

    Typed view over the raw provider response dict; field names follow
    the provider payload — presumably an Anthropic-style message shape
    (TODO confirm against the concrete client implementations).
    """
    id: str  # message identifier from the provider
    type: str  # payload type discriminator
    role: str  # e.g. the responding role of the message
    content: List[Dict[str, Any]]  # list of content blocks (dicts)
    model: str  # model that produced the response
    stop_reason: Optional[str]  # why generation stopped, if reported
    stop_sequence: Optional[str]  # matched stop sequence, if any
    usage: Dict[str, int]  # token accounting, provider-specific keys
|
|
123
|
+
|
|
124
|
+
@dataclass
class RetryConfig:
    """Configuration for MAX_TOKENS retry behavior.

    Attributes:
        max_retries: How many times a failed call may be retried.
        token_increase_threshold: Only requests at or below this token
            budget are eligible for a retry with a larger budget.
        new_token_limit: Token budget applied on the retry.
        error_patterns: Uppercase regex patterns matched against the
            error text to recognize token-limit failures. ``None``
            (the default sentinel) is replaced with the built-in
            pattern list in ``__post_init__``.
    """
    max_retries: int = 1
    token_increase_threshold: int = 1024
    new_token_limit: int = 8192
    # Fix: the field is a None-sentinel, so the annotation must be
    # Optional — the original ``List[str] = None`` lied to type checkers.
    error_patterns: Optional[List[str]] = None

    def __post_init__(self):
        # Fill in the default recognizers only when the caller did not
        # supply an explicit pattern list.
        if self.error_patterns is None:
            self.error_patterns = [
                r"MAX_TOKENS?",
                r"TOKEN.*LIMIT",
                r"CONTEXT.*LENGTH",
                r"TOO.*MANY.*TOKENS"
            ]
|
|
140
|
+
|
|
141
|
+
class TokenRetryMixin:
    """Mixin that equips an LLM client with token-limit retry helpers.

    Installs a default ``RetryConfig`` on the instance and offers
    predicates for recognizing token-limit errors plus a policy for
    picking an enlarged token budget on retry.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Default retry policy; callers may replace self.retry_config.
        self.retry_config = RetryConfig()

    def is_token_limit_error(self, error: Exception) -> bool:
        """Return True when *error*'s text matches a token-limit pattern."""
        text = str(error).upper()
        for pattern in self.retry_config.error_patterns:
            if re.search(pattern, text):
                return True
        return False

    def should_retry_with_more_tokens(self, current_tokens: int, retry_count: int) -> bool:
        """Decide whether another attempt with a bigger budget is allowed."""
        if retry_count >= self.retry_config.max_retries:
            return False
        return current_tokens <= self.retry_config.token_increase_threshold

    def get_increased_token_limit(self, current_tokens: int) -> int:
        """Pick the next token budget: tiered bumps, then doubling, capped at 16k."""
        tiers = (
            (1024, 4096),
            (4096, 8192),
            (8192, 12288),
        )
        for ceiling, bumped in tiers:
            if current_tokens <= ceiling:
                return bumped
        # Beyond the tiers, double the budget but never exceed 16k tokens.
        return min(current_tokens * 2, 16384)
|
|
174
|
+
|
|
175
|
+
@dataclass
class BatchRequest:
    """Data structure for batch request.

    Pairs a caller-chosen correlation id with the request parameters so
    results of a batch call can be matched back to their inputs.
    """
    custom_id: str  # caller-supplied id used to correlate the batch result
    params: Dict[str, Any]  # request parameters for this batch item
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class StreamingRetryConfig:
    """Configuration for streaming retry behavior.

    Bundles the exponential-backoff knobs (base/max delay, factor,
    jitter), the token-limit auto-retry switch and growth factor, and
    flags for retrying on rate-limit or server errors.
    """
    def __init__(
        self,
        max_retries: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        backoff_factor: float = 2.0,
        jitter: bool = True,
        auto_retry_on_max_tokens: bool = True,
        token_increase_factor: float = 1.5,
        retry_on_rate_limit: bool = True,
        retry_on_server_error: bool = True
    ):
        # Mirror every constructor argument onto the instance verbatim.
        settings = {
            'max_retries': max_retries,
            'base_delay': base_delay,
            'max_delay': max_delay,
            'backoff_factor': backoff_factor,
            'jitter': jitter,
            'auto_retry_on_max_tokens': auto_retry_on_max_tokens,
            'token_increase_factor': token_increase_factor,
            'retry_on_rate_limit': retry_on_rate_limit,
            'retry_on_server_error': retry_on_server_error,
        }
        for attr, value in settings.items():
            setattr(self, attr, value)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class AbstractClient(ABC):
    """Abstract base Class for LLM models.

    Concrete provider clients (Claude, GPT, Gemini, ...) subclass this
    and implement ``get_client``. Supports async context-manager usage
    and optional tool registration via a ToolManager.
    """
    # Default API version; instances may override via the ``version`` kwarg.
    version: str = "0.1.0"
    # NOTE(review): class-level mutable dict shared by every subclass and
    # instance — any in-place ``update`` on it leaks headers across
    # instances; confirm whether that sharing is intended.
    base_headers: Dict[str, str] = {
        "Content-Type": "application/json",
    }
    client_type: str = "generic"
    client_name: str = 'generic'
    # When True, __aenter__ opens an aiohttp.ClientSession for HTTP calls.
    use_session: bool = False
|
|
216
|
+
|
|
217
|
+
def __init__(
|
|
218
|
+
self,
|
|
219
|
+
conversation_memory: Optional[ConversationMemory] = None,
|
|
220
|
+
preset: Optional[str] = None,
|
|
221
|
+
tools: Optional[List[Union[str, AbstractTool]]] = None,
|
|
222
|
+
use_tools: bool = False,
|
|
223
|
+
debug: bool = True,
|
|
224
|
+
**kwargs
|
|
225
|
+
):
|
|
226
|
+
self.__name__ = self.__class__.__name__
|
|
227
|
+
self.model: str = kwargs.get('model', None)
|
|
228
|
+
self.client: Any = None
|
|
229
|
+
self.session: Optional[aiohttp.ClientSession] = None
|
|
230
|
+
self.use_session: bool = kwargs.get('use_session', self.use_session)
|
|
231
|
+
if preset:
|
|
232
|
+
preset_config = LLM_PRESETS.get(preset, LLM_PRESETS['default'])
|
|
233
|
+
# define temp, top_k, top_p, max_tokens from selected preset:
|
|
234
|
+
self.temperature = preset_config.get('temperature', 0.4)
|
|
235
|
+
self.top_k = preset_config.get('top_k', 30)
|
|
236
|
+
self.top_p = preset_config.get('top_p', 0.2)
|
|
237
|
+
self.max_tokens = preset_config.get('max_tokens', 4096)
|
|
238
|
+
else:
|
|
239
|
+
# define default values from preset default:
|
|
240
|
+
self.temperature = kwargs.get('temperature', 0)
|
|
241
|
+
self.top_k = kwargs.get('top_k', 30)
|
|
242
|
+
self.top_p = kwargs.get('top_p', 0.2)
|
|
243
|
+
self.max_tokens = kwargs.get('max_tokens', 4096)
|
|
244
|
+
self.conversation_memory = conversation_memory or InMemoryConversation()
|
|
245
|
+
self.base_headers.update(kwargs.get('headers', {}))
|
|
246
|
+
self.api_key = kwargs.get('api_key', None)
|
|
247
|
+
self.version = kwargs.get('version', self.version)
|
|
248
|
+
self._config = config
|
|
249
|
+
self.logger: logging.Logger = logging.getLogger(self.__name__)
|
|
250
|
+
self._json: Any = JSONContent()
|
|
251
|
+
self.client_type: str = kwargs.get('client_type', self.client_type)
|
|
252
|
+
self._debug: bool = debug
|
|
253
|
+
self._program: str = kwargs.get('program', 'parrot') # Default program slug
|
|
254
|
+
# Initialize ToolManager instead of direct tools dict
|
|
255
|
+
self.tool_manager = ToolManager(
|
|
256
|
+
logger=self.logger,
|
|
257
|
+
debug=self._debug
|
|
258
|
+
)
|
|
259
|
+
self.tools: Dict[str, Union[ToolDefinition, AbstractTool]] = {}
|
|
260
|
+
self.enable_tools: bool = use_tools
|
|
261
|
+
# Initialize tools if provided
|
|
262
|
+
if use_tools and tools:
|
|
263
|
+
self.tool_manager.default_tools(tools)
|
|
264
|
+
self.enable_tools = True
|
|
265
|
+
|
|
266
|
+
@property
def default_model(self) -> str:
    """Name of the client's default model, or None when not configured."""
    model_name = getattr(self, '_default_model', None)
    return model_name
|
|
270
|
+
|
|
271
|
+
@abstractmethod
async def get_client(self) -> Any:
    """Create and return the provider-specific client instance.

    Subclasses must implement this. It is invoked lazily from
    ``__aenter__`` when ``self.client`` has not been initialized yet.
    """
    raise NotImplementedError
|
|
275
|
+
|
|
276
|
+
async def __aenter__(self):
|
|
277
|
+
"""Initialize the client context."""
|
|
278
|
+
if self.use_session:
|
|
279
|
+
self.session = aiohttp.ClientSession(
|
|
280
|
+
headers=self.base_headers
|
|
281
|
+
)
|
|
282
|
+
if not self.client:
|
|
283
|
+
self.client = await self.get_client()
|
|
284
|
+
return self
|
|
285
|
+
|
|
286
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
287
|
+
if self.session:
|
|
288
|
+
await self.session.close()
|
|
289
|
+
return False
|
|
290
|
+
|
|
291
|
+
async def close(self):
    """Release underlying resources.

    Closes the provider client (when it exposes ``close``) and the HTTP
    session, mirroring the cleanup performed in ``__aexit__`` so that
    callers who do not use the context manager do not leak the session.
    """
    if self.client and hasattr(self.client, 'close'):
        await self.client.close()
    # Bug fix: previously only __aexit__ closed the session, so direct
    # close() calls leaked the aiohttp connection pool.
    if self.session:
        await self.session.close()
        self.session = None
|
|
294
|
+
|
|
295
|
+
def __repr__(self):
|
|
296
|
+
return f'<{self.__name__} model={self.model} client_type={self.client_type}>'
|
|
297
|
+
|
|
298
|
+
def set_program(self, program_slug: str) -> None:
    """Assign the program slug used to scope this client's operations."""
    self._program = program_slug
|
|
301
|
+
|
|
302
|
+
def _get_chatbot_key(self, chatbot_id: Optional[str] = None) -> Optional[str]:
|
|
303
|
+
"""Resolve chatbot identifier for memory operations."""
|
|
304
|
+
key = chatbot_id or getattr(self, 'chatbot_id', None)
|
|
305
|
+
return None if key is None else str(key)
|
|
306
|
+
|
|
307
|
+
async def start_conversation(
    self,
    user_id: str,
    session_id: str,
    metadata: Optional[Dict[str, Any]] = None,
    chatbot_id: Optional[str] = None,
) -> ConversationHistory:
    """Create a fresh conversation history for the given user/session pair."""
    resolved_bot = self._get_chatbot_key(chatbot_id)
    return await self.conversation_memory.create_history(
        user_id,
        session_id,
        metadata=metadata,
        chatbot_id=resolved_bot
    )
|
|
321
|
+
|
|
322
|
+
async def get_conversation(
    self,
    user_id: str,
    session_id: str,
    chatbot_id: Optional[str] = None
) -> Optional[ConversationHistory]:
    """Fetch an existing conversation history, or None when memory is disabled."""
    memory = self.conversation_memory
    if not memory:
        return None
    return await memory.get_history(
        user_id,
        session_id,
        chatbot_id=self._get_chatbot_key(chatbot_id)
    )
|
|
336
|
+
|
|
337
|
+
async def clear_conversation(
    self,
    user_id: str,
    session_id: str,
    chatbot_id: Optional[str] = None
) -> bool:
    """Erase stored turns for a session; False when memory is disabled."""
    memory = self.conversation_memory
    if not memory:
        return False
    await memory.clear_history(
        user_id,
        session_id,
        chatbot_id=self._get_chatbot_key(chatbot_id)
    )
    return True
|
|
352
|
+
|
|
353
|
+
async def delete_conversation(
    self,
    user_id: str,
    session_id: str,
    chatbot_id: Optional[str] = None
) -> bool:
    """Remove the whole conversation history; False when memory is disabled."""
    memory = self.conversation_memory
    if not memory:
        return False
    return await memory.delete_history(
        user_id,
        session_id,
        chatbot_id=self._get_chatbot_key(chatbot_id)
    )
|
|
367
|
+
|
|
368
|
+
async def list_user_conversations(
    self,
    user_id: str,
    chatbot_id: Optional[str] = None
) -> List[str]:
    """Session ids known for this user; empty when memory is disabled."""
    memory = self.conversation_memory
    if not memory:
        return []
    return await memory.list_sessions(
        user_id,
        chatbot_id=self._get_chatbot_key(chatbot_id)
    )
|
|
380
|
+
|
|
381
|
+
def set_tools(self, tools: List[Union[str, AbstractTool]]) -> None:
    """Replace every registered tool with the given collection."""
    # Drop both the managed registry and the legacy dict before re-registering.
    self.tool_manager.clear_tools()
    self.tools.clear()
    self.register_tools(tools)
|
|
386
|
+
|
|
387
|
+
def get_tool(self, name: str) -> Optional[AbstractTool]:
    """Look up a tool by name, preferring the ToolManager over the legacy dict."""
    managed = self.tool_manager.get_tool(name)
    if managed:
        return managed
    # Legacy fallback: only genuine AbstractTool instances are surfaced.
    candidate = self.tools.get(name)
    if isinstance(candidate, AbstractTool):
        return candidate
    return None
|
|
396
|
+
|
|
397
|
+
def register_tool(
    self,
    tool: Union[ToolDefinition, AbstractTool] = None,
    name: str = None,
    description: str = None,
    input_schema: Dict[str, Any] = None,
    function: Callable = None,
) -> None:
    """Register a Python function as a tool for LLM to call.

    Thin delegate to ``ToolManager.register_tool``: pass either a ready
    ``tool`` object, or the ``name``/``description``/``input_schema``/
    ``function`` pieces for the manager to assemble a definition from.
    """
    self.tool_manager.register_tool(
        tool=tool,
        name=name,
        description=description,
        input_schema=input_schema,
        function=function
    )
|
|
413
|
+
|
|
414
|
+
def register_tools(
    self,
    tools: List[Union[ToolDefinition, AbstractTool]]
) -> None:
    """Bulk-register tools and switch tool usage on."""
    self.tool_manager.register_tools(tools)
    # Registering tools implies the client should expose them to the model.
    self.enable_tools = True
|
|
421
|
+
|
|
422
|
+
def register_python_tool(
    self,
    report_dir: Optional[Path] = None,
    plt_style: str = 'seaborn-v0_8-whitegrid',
    palette: str = 'Set2'
) -> PythonREPLTool:
    """Register (once) and return a Python REPL tool on this client.

    Args:
        report_dir: Directory for saving generated reports.
        plt_style: Matplotlib style name.
        palette: Seaborn color palette name.

    Returns:
        The PythonREPLTool instance (the cached one on repeated calls).
    """
    # Bug fix: new instances are added to the ToolManager below, but the
    # memoization check previously only looked at the legacy self.tools
    # dict, so every call built a fresh tool. Check the manager first.
    existing = self.tool_manager.get_tool("python_repl")
    if existing is not None:
        return existing
    if "python_repl" in self.tools:
        return self.tools["python_repl"]

    tool = PythonREPLTool(
        report_dir=report_dir,
        plt_style=plt_style,
        palette=palette,
        debug=self._debug,
    )
    self.tool_manager.add_tool(tool)
    return tool
|
|
450
|
+
|
|
451
|
+
def list_tools(self) -> List[str]:
    """Names of all registered tools: managed first, then legacy-only names."""
    managed = self.tool_manager.list_tools()
    seen = set(managed)
    extras = [n for n in self.tools if n not in seen]
    return managed + extras
|
|
456
|
+
|
|
457
|
+
def remove_tool(self, name: str) -> bool:
    """
    Remove a tool by name.

    Args:
        name: Tool name to remove

    Returns:
        True if tool was removed, False if not found
    """
    # Bug fix: the manager's result was previously discarded, so callers
    # always got None despite the documented bool return value.
    return self.tool_manager.remove_tool(name)
|
|
468
|
+
|
|
469
|
+
def clear_tools(self) -> None:
    """Drop every registered tool (managed and legacy) and log the action."""
    self.tool_manager.clear_tools()
    self.tools.clear()
    self.logger.info("Cleared all tools")
|
|
476
|
+
|
|
477
|
+
def _encode_file(self, file_path: Union[str, Path]) -> Dict[str, Any]:
|
|
478
|
+
"""Encode file for API upload."""
|
|
479
|
+
path = Path(file_path)
|
|
480
|
+
mime_type, _ = mimetypes.guess_type(str(path))
|
|
481
|
+
|
|
482
|
+
with open(path, "rb") as f:
|
|
483
|
+
encoded = base64.b64encode(f.read()).decode('utf-8')
|
|
484
|
+
|
|
485
|
+
return {
|
|
486
|
+
"type": "document",
|
|
487
|
+
"source": {
|
|
488
|
+
"type": "base64",
|
|
489
|
+
"media_type": mime_type or "application/octet-stream",
|
|
490
|
+
"data": encoded
|
|
491
|
+
}
|
|
492
|
+
}
|
|
493
|
+
|
|
494
|
+
def _make_openai_strict_tool(self, schema: Dict[str, Any]) -> Dict[str, Any]:
|
|
495
|
+
"""
|
|
496
|
+
Ensure the tool schema matches OpenAI strict function-tool requirements:
|
|
497
|
+
- type=function
|
|
498
|
+
- function.strict = True
|
|
499
|
+
- function.parameters is an object schema with additionalProperties = False
|
|
500
|
+
"""
|
|
501
|
+
if schema.get("type") != "function":
|
|
502
|
+
return schema
|
|
503
|
+
|
|
504
|
+
fn = schema.setdefault("function", {})
|
|
505
|
+
params = fn.setdefault("parameters", {})
|
|
506
|
+
|
|
507
|
+
# Ensure base object shape
|
|
508
|
+
if params.get("type") is None:
|
|
509
|
+
params["type"] = "object"
|
|
510
|
+
if "properties" not in params:
|
|
511
|
+
params["properties"] = {}
|
|
512
|
+
|
|
513
|
+
# ✅ NEW: normalize recursively for OpenAI strict rules
|
|
514
|
+
params = self._oai_normalize_schema(params)
|
|
515
|
+
fn["parameters"] = params
|
|
516
|
+
|
|
517
|
+
# Mark strict
|
|
518
|
+
fn["strict"] = True
|
|
519
|
+
return schema
|
|
520
|
+
|
|
521
|
+
self.logger.debug(f"Prepared {len(tool_schemas)} tool schemas")
|
|
522
|
+
return tool_schemas
|
|
523
|
+
|
|
524
|
+
def _check_new_tools(self, tool_name: str, tool_result_content: str) -> List[str]:
|
|
525
|
+
"""
|
|
526
|
+
Check if search_tools was called and return any found tool names.
|
|
527
|
+
|
|
528
|
+
Args:
|
|
529
|
+
tool_name: Name of the executed tool
|
|
530
|
+
tool_result_content: Content returned by the tool (JSON string)
|
|
531
|
+
|
|
532
|
+
Returns:
|
|
533
|
+
List of found tool names
|
|
534
|
+
"""
|
|
535
|
+
if tool_name != "search_tools":
|
|
536
|
+
return []
|
|
537
|
+
|
|
538
|
+
try:
|
|
539
|
+
# Result should be a JSON string of list of dicts
|
|
540
|
+
import json
|
|
541
|
+
found_tools = json.loads(tool_result_content)
|
|
542
|
+
if isinstance(found_tools, list):
|
|
543
|
+
return [t.get("name") for t in found_tools if isinstance(t, dict) and "name" in t]
|
|
544
|
+
except Exception as e:
|
|
545
|
+
self.logger.warning(f"Failed to parse search_tools result: {e}")
|
|
546
|
+
|
|
547
|
+
return []
|
|
548
|
+
|
|
549
|
+
def _prepare_lazy_tools(self, tool_choice: str = "auto") -> List[Dict[str, Any]]:
|
|
550
|
+
"""
|
|
551
|
+
Prepare only the search tool and essential tools for lazy loading.
|
|
552
|
+
"""
|
|
553
|
+
# Always include search_tools
|
|
554
|
+
lazy_tools = ["search_tools"]
|
|
555
|
+
# Maybe include some basics if defined in a preset
|
|
556
|
+
|
|
557
|
+
schemas = []
|
|
558
|
+
for name in lazy_tools:
|
|
559
|
+
if tool := self.tool_manager.get_tool(name):
|
|
560
|
+
# Reuse _prepare_tools logic but for specific tools?
|
|
561
|
+
# _prepare_tools iterates ALL tools in manager.
|
|
562
|
+
# We should probably filter _prepare_tools.
|
|
563
|
+
pass
|
|
564
|
+
|
|
565
|
+
# ACTUALLY, simpler:
|
|
566
|
+
# If lazy loading, we just return the schema for 'search_tools'
|
|
567
|
+
# tool_manager.get_tool_schemas can be updated/used?
|
|
568
|
+
# Or we manually fetch schema for search_tools.
|
|
569
|
+
|
|
570
|
+
# Let's rely on ToolManager.get_tool_schemas supporting filtering?
|
|
571
|
+
# I didn't add filtering to ToolManager.get_tool_schemas yet.
|
|
572
|
+
# I should have done that.
|
|
573
|
+
# But I can just fetch the tool and get its schema.
|
|
574
|
+
|
|
575
|
+
search_tool = self.tool_manager.get_tool("search_tools")
|
|
576
|
+
if not search_tool:
|
|
577
|
+
self.logger.warning("search_tools not found for lazy loading")
|
|
578
|
+
return []
|
|
579
|
+
|
|
580
|
+
# We need to adapt the schema using the same logic as _prepare_tools
|
|
581
|
+
# _prepare_tools calls tool_manager.get_tool_schemas()
|
|
582
|
+
|
|
583
|
+
# I will hack specific getting for now to avoid modifying ToolManager again if possible,
|
|
584
|
+
# but modifying ToolManager to filter is cleaner.
|
|
585
|
+
# Let's assume I can iterate and filter here.
|
|
586
|
+
|
|
587
|
+
return self._prepare_tools(filter_names=["search_tools"])
|
|
588
|
+
|
|
589
|
+
def _prepare_tools(self, filter_names: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """Convert registered tools into the wire format for the active provider.

    Args:
        filter_names: When given, only tools whose name appears here are
            included in the result.

    Returns:
        A list of provider-formatted tool schemas, de-duplicated by name.
    """
    # Map client type to the ToolManager's schema dialect; Anthropic/Claude
    # is the default for unknown client types.
    format_by_client = {
        'openai': ToolFormat.OPENAI,
        'google': ToolFormat.GOOGLE,
        'groq': ToolFormat.GROQ,
        'vertex': ToolFormat.VERTEX,
    }
    provider_format = format_by_client.get(self.client_type, ToolFormat.ANTHROPIC)

    schemas = []
    seen = set()  # guards against duplicate tool names

    for raw_schema in self.tool_manager.get_tool_schemas(provider_format=provider_format):
        entry = raw_schema.copy()
        # Internal handle, not part of the API payload.
        entry.pop('_tool_instance', None)

        name = entry.get('name')
        if filter_names is not None and name not in filter_names:
            continue
        if not name or name in seen:
            continue

        if self.client_type == 'openai':
            # OpenAI wraps tools in a "function" envelope and requires
            # strict-mode normalization.
            formatted = self._make_openai_strict_tool({
                "type": "function",
                "function": {
                    "name": entry["name"],
                    "description": entry["description"],
                    "parameters": entry.get("parameters", {}),
                },
            })
        else:
            # Claude/Anthropic-style flat schema (also used by the rest).
            formatted = {
                "name": entry["name"],
                "description": entry["description"],
                "input_schema": entry.get("parameters", {}),
            }

        schemas.append(formatted)
        seen.add(name)

    self.logger.debug(f"Prepared {len(schemas)} tool schemas")
    return schemas
|
|
648
|
+
|
|
649
|
+
async def _execute_tool(
    self,
    tool_name: str,
    parameters: Dict[str, Any]
) -> Any:
    """Run a registered tool via the ToolManager and unwrap its result.

    Raises:
        ValueError: when the tool reports an error status (logged first,
            like any other failure, then re-raised).
    """
    try:
        outcome = await self.tool_manager.execute_tool(tool_name, parameters)
        if isinstance(outcome, ToolResult):
            if outcome.status == "error":
                raise ValueError(outcome.error)
            return outcome.result
        return outcome
    except Exception as e:
        self.logger.error(
            f"Error executing tool {tool_name}: {e}"
        )
        raise
|
|
667
|
+
|
|
668
|
+
async def _execute_tool_call(
|
|
669
|
+
self,
|
|
670
|
+
content_block: Dict[str, Any]
|
|
671
|
+
) -> Dict[str, Any]:
|
|
672
|
+
"""Execute a single tool call and return the result."""
|
|
673
|
+
tool_name = content_block["name"]
|
|
674
|
+
tool_input = content_block["input"]
|
|
675
|
+
tool_id = content_block["id"]
|
|
676
|
+
|
|
677
|
+
try:
|
|
678
|
+
tool_result = await self._execute_tool(tool_name, tool_input)
|
|
679
|
+
return {
|
|
680
|
+
"type": "tool_result",
|
|
681
|
+
"tool_use_id": tool_id,
|
|
682
|
+
"content": str(tool_result)
|
|
683
|
+
}
|
|
684
|
+
except Exception as e:
|
|
685
|
+
return {
|
|
686
|
+
"type": "tool_result",
|
|
687
|
+
"tool_use_id": tool_id,
|
|
688
|
+
"is_error": True,
|
|
689
|
+
"content": str(e)
|
|
690
|
+
}
|
|
691
|
+
|
|
692
|
+
def _prepare_messages(
|
|
693
|
+
self,
|
|
694
|
+
prompt: str,
|
|
695
|
+
files: Optional[List[Union[str, Path]]] = None
|
|
696
|
+
) -> List[Dict[str, Any]]:
|
|
697
|
+
"""Prepare message content with optional file attachments."""
|
|
698
|
+
content = [{"type": "text", "text": prompt}]
|
|
699
|
+
|
|
700
|
+
if files:
|
|
701
|
+
content.extend(self._encode_file(file_path) for file_path in files)
|
|
702
|
+
|
|
703
|
+
return [{"role": "user", "content": content}]
|
|
704
|
+
|
|
705
|
+
def _validate_response(self, response: Dict[str, Any]) -> bool:
|
|
706
|
+
"""Validate API response structure."""
|
|
707
|
+
required_fields = ["id", "type", "role", "content", "model"]
|
|
708
|
+
return all(field in response for field in required_fields)
|
|
709
|
+
|
|
710
|
+
def _get_structured_config(
    self,
    structured_output: Union[type, StructuredOutputConfig, None]
) -> Optional[StructuredOutputConfig]:
    """Normalize the structured-output argument into a config object (or None)."""
    if isinstance(structured_output, StructuredOutputConfig):
        return structured_output
    if not structured_output:
        return None
    # A bare type is wrapped with the default JSON output format.
    return StructuredOutputConfig(
        output_type=structured_output,
        format=OutputFormat.JSON
    )
|
|
723
|
+
|
|
724
|
+
def _ensure_json_instruction(
    self,
    messages: List[Dict[str, Any]],
    instruction: str
) -> None:
    """Ensure the latest user message explicitly requests JSON output.

    Mutates ``messages`` in place: the instruction is appended to the
    first text block of the most recent user message, unless that message
    already contains it (case-insensitive). When no user message exists,
    a new one carrying only the instruction is appended.
    """
    if not instruction:
        return

    lowered_instruction = instruction.lower()

    # Walk backwards so the most recent user message is the one modified.
    for message in reversed(messages):
        if message.get("role") != "user":
            continue

        existing_content = message.get("content")
        if isinstance(existing_content, str):
            # Plain-string content: bail if already instructed, otherwise
            # normalize to block-list form so the loop below can extend it.
            if lowered_instruction in existing_content.lower():
                return
            message["content"] = [{"type": "text", "text": existing_content}]

        content = message.setdefault("content", [])
        for block in content:
            if block.get("type") == "text":
                text = block.get("text", "")
                if lowered_instruction in text.lower():
                    return
                # Only the first text block receives the instruction.
                block["text"] = f"{text}\n\n{instruction}" if text else instruction
                return

        # No text block present: add the instruction as its own block.
        content.append({"type": "text", "text": instruction})
        return

    # No user message found at all: create one with just the instruction.
    messages.append({
        "role": "user",
        "content": [{"type": "text", "text": instruction}]
    })
|
|
761
|
+
|
|
762
|
+
@abstractmethod
async def ask(
    self,
    prompt: str,
    model: str,
    max_tokens: int = 4096,
    temperature: float = 0.7,
    files: Optional[List[Union[str, Path]]] = None,
    system_prompt: Optional[str] = None,
    structured_output: Union[type, StructuredOutputConfig, None] = None,
    user_id: Optional[str] = None,
    session_id: Optional[str] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    use_tools: Optional[bool] = None,
    deep_research: bool = False,
    background: bool = False,
    lazy_loading: bool = False,
) -> MessageResponse:
    """Send a prompt to the model and return the response.

    Args:
        prompt: The input prompt for the model
        model: The model to use
        max_tokens: Maximum number of tokens in the response
        temperature: Sampling temperature for response generation
        files: Optional files to include in the request
        system_prompt: Optional system prompt to guide the model
        structured_output: Optional structured output configuration
        user_id: Optional user identifier for tracking
        session_id: Optional session identifier for tracking
        tools: Optional tools to register for this call
        use_tools: Whether to use tools
        deep_research: If True, use deep research mode (provider-specific)
        background: If True, execute research in background (async mode)
        lazy_loading: If True, enables dynamic tool searching

    Returns:
        The provider's reply for this turn as a MessageResponse.
    """
    raise NotImplementedError("Subclasses must implement this method.")
|
|
799
|
+
|
|
800
|
+
@abstractmethod
async def ask_stream(
    self,
    prompt: str,
    model: str = None,
    max_tokens: int = 4096,
    temperature: float = 0.7,
    files: Optional[List[Union[str, Path]]] = None,
    system_prompt: Optional[str] = None,
    user_id: Optional[str] = None,
    session_id: Optional[str] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    deep_research: bool = False,
    agent_config: Optional[Dict[str, Any]] = None,
    lazy_loading: bool = False,
) -> AsyncIterator[str]:
    """Stream the model's response as an async iterator of text chunks.

    Args:
        prompt: The input prompt for the model
        model: The model to use
        max_tokens: Maximum number of tokens in the response
        temperature: Sampling temperature for response generation
        files: Optional files to include in the request
        system_prompt: Optional system prompt to guide the model
        user_id: Optional user identifier for tracking
        session_id: Optional session identifier for tracking
        tools: Optional tools to register for this call
        deep_research: If True, use deep research mode (provider-specific)
        agent_config: Optional configuration for deep research agent (e.g., thinking_summaries)
        lazy_loading: If True, enables dynamic tool searching
    """
    raise NotImplementedError("Subclasses must implement this method.")
|
|
833
|
+
|
|
834
|
+
@abstractmethod
async def batch_ask(self, requests: List[Any]) -> List[Any]:
    """Process multiple requests in batch.

    Args:
        requests: Provider-specific request payloads, one per prompt.

    Returns:
        One response per request, in the same order.
    """
    raise NotImplementedError("Subclasses must implement batch processing.")
|
|
838
|
+
|
|
839
|
+
async def _handle_structured_output(
|
|
840
|
+
self,
|
|
841
|
+
result: Dict[str, Any],
|
|
842
|
+
structured_output: Optional[type]
|
|
843
|
+
) -> Any:
|
|
844
|
+
"""Parse response into structured output format."""
|
|
845
|
+
if not structured_output:
|
|
846
|
+
return result
|
|
847
|
+
|
|
848
|
+
text_content = "".join(
|
|
849
|
+
content_block["text"]
|
|
850
|
+
for content_block in result["content"]
|
|
851
|
+
if content_block["type"] == "text"
|
|
852
|
+
)
|
|
853
|
+
|
|
854
|
+
try:
|
|
855
|
+
if not hasattr(structured_output, '__annotations__'):
|
|
856
|
+
return structured_output(text_content)
|
|
857
|
+
parsed = json_decoder(text_content)
|
|
858
|
+
return self._coerce_mapping_to_type(structured_output, parsed)
|
|
859
|
+
except Exception: # pylint: disable=broad-except
|
|
860
|
+
return result
|
|
861
|
+
|
|
862
|
+
def _coerce_mapping_to_type(self, output_type: type, data: Any) -> Any:
|
|
863
|
+
"""Attempt to instantiate output_type from mapping-like data."""
|
|
864
|
+
if data is None:
|
|
865
|
+
return None
|
|
866
|
+
|
|
867
|
+
if is_dataclass(output_type):
|
|
868
|
+
try:
|
|
869
|
+
if isinstance(data, list):
|
|
870
|
+
return [self._coerce_mapping_to_type(output_type, item) for item in data]
|
|
871
|
+
if isinstance(data, dict):
|
|
872
|
+
return output_type(**data)
|
|
873
|
+
except TypeError:
|
|
874
|
+
return data
|
|
875
|
+
return data
|
|
876
|
+
|
|
877
|
+
if hasattr(output_type, '__annotations__'):
|
|
878
|
+
if isinstance(data, list):
|
|
879
|
+
coerced = []
|
|
880
|
+
for item in data:
|
|
881
|
+
if isinstance(item, dict):
|
|
882
|
+
try:
|
|
883
|
+
coerced.append(output_type(**item))
|
|
884
|
+
except TypeError:
|
|
885
|
+
coerced.append(item)
|
|
886
|
+
else:
|
|
887
|
+
coerced.append(item)
|
|
888
|
+
return coerced
|
|
889
|
+
if isinstance(data, dict):
|
|
890
|
+
try:
|
|
891
|
+
return output_type(**data)
|
|
892
|
+
except TypeError:
|
|
893
|
+
return data
|
|
894
|
+
|
|
895
|
+
return data
|
|
896
|
+
|
|
897
|
+
async def _process_tool_calls(
    self,
    initial_result: Dict[str, Any],
    messages: List[Dict[str, Any]],
    payload: Dict[str, Any],
    endpoint: str
) -> Dict[str, Any]:
    """Handle tool calls in a loop until completion.

    While the model stops with ``stop_reason == "tool_use"``, every
    tool_use block is executed and the results are posted back through
    ``self.session``, mutating ``messages`` and ``payload`` in place,
    until the model produces a final answer.

    Args:
        initial_result: First API response (Anthropic-style dict).
        messages: Conversation messages; extended in place with each
            assistant turn and the corresponding tool results.
        payload: Request payload re-posted on every round trip.
        endpoint: URL posted to via the aiohttp session.

    Returns:
        The final API response; its content is also appended to
        ``messages`` as the closing assistant turn.

    Raises:
        Propagates HTTP errors from ``raise_for_status``.
    """
    result = initial_result

    while result.get("stop_reason") == "tool_use":
        tool_results = []

        # Execute every tool_use block of this assistant turn.
        for content_block in result["content"]:
            if content_block["type"] == "tool_use":
                tool_result = await self._execute_tool_call(content_block)
                tool_results.append(tool_result)

        # Feed the assistant turn plus the tool results back to the model.
        messages.append({"role": "assistant", "content": result["content"]})
        messages.append({"role": "user", "content": tool_results})
        payload["messages"] = messages

        async with self.session.post(endpoint, json=payload) as response:
            response.raise_for_status()
            result = await response.json()

    # Add final assistant response
    messages.append({"role": "assistant", "content": result["content"]})
    return result
|
|
926
|
+
|
|
927
|
+
async def _prepare_conversation_context(
    self,
    prompt: str,
    files: Optional[List[Union[str, Path]]],
    user_id: Optional[str],
    session_id: Optional[str],
    system_prompt: Optional[str],
    stateless: bool = False
) -> tuple[List[Dict[str, Any]], Optional[ConversationHistory], Optional[str]]:
    """Prepare conversation context and return messages, session, and system prompt.

    Loads (or creates) the conversation history for ``user_id``/``session_id``,
    builds the message list for the API call, and derives a contextual
    system prompt from the last few turns when none was supplied.
    """
    messages = []
    conversation_history = None

    if user_id and session_id:
        conversation_history = await self.conversation_memory.get_history(
            user_id,
            session_id,
            chatbot_id=self._get_chatbot_key()
        )
        if not conversation_history:
            # First message of this session: start a fresh history.
            conversation_history = await self.conversation_memory.create_history(
                user_id,
                session_id,
                chatbot_id=self._get_chatbot_key()
            )

    # Get recent conversation messages for context
    if conversation_history:
        messages = conversation_history.get_messages_for_api()
        new_user_message = self._prepare_messages(prompt, files)[0]
        # NOTE(review): the current prompt is appended again near the end of
        # this method, and history turns are re-appended below on top of
        # get_messages_for_api() — this looks like it may duplicate both the
        # prompt and the history. Confirm get_messages_for_api() semantics
        # before changing.
        messages.append(new_user_message)

    # Convert stored conversation turns to messages format and create system prompt:
    if conversation_history and not stateless:
        self.logger.debug(
            f"Found {len(conversation_history.turns)} previous turns"
        )
        for turn in conversation_history.turns:
            # Add user message
            messages.append({
                "role": "user",
                "content": [{"type": "text", "text": turn.user_message}]
            })

            # Add assistant message
            messages.append({
                "role": "assistant",
                "content": [{"type": "text", "text": turn.assistant_response}]
            })

        if not system_prompt and len(conversation_history.turns) > 0:
            # Create a summary of the conversation context
            recent_context = []
            for turn in conversation_history.turns[-3:]:  # Last 3 turns for context
                recent_context.extend(
                    (
                        f"User: {turn.user_message}",
                        f"Assistant: {turn.assistant_response}",
                    )
                )

            recent = "\n".join(recent_context)
            system_prompt = (
                "You are a helpful AI assistant. You have access to the following conversation history:\n\n"
                f"{recent}"
                "\n\nUse this context to provide relevant and consistent responses. "
                "When users refer to previously mentioned information, acknowledge and use that context."
            )
            self.logger.debug("Created contextual system prompt from conversation history")

    # Handle file attachments if provided
    current_message_parts = [{"type": "text", "text": prompt}]
    if files:
        for file_path in files:
            try:
                file_path = Path(file_path)
                if file_path.exists():
                    # Files are passed by reference here (not base64-encoded);
                    # presumably resolved downstream — confirm against callers.
                    current_message_parts.append({
                        "type": "file",
                        "file_path": str(file_path)
                    })
            except Exception as e:
                self.logger.error(f"Error processing file {file_path}: {e}")

    # Add the current user message
    messages.append({
        "role": "user",
        "content": current_message_parts
    })

    # self.logger.debug(f"Prepared {len(messages)} messages for conversation context")
    return messages, conversation_history, system_prompt
|
|
1019
|
+
|
|
1020
|
+
async def _update_conversation_memory(
    self,
    user_id: Optional[str],
    session_id: Optional[str],
    conversation_history: Optional[ConversationHistory],
    messages: List[Dict[str, Any]],
    system_prompt: Optional[str],
    turn_id: str,
    original_prompt: str,
    assistant_response: str,
    tools_used: List[str] = None
) -> None:
    """Persist the latest user/assistant exchange into conversation memory.

    A no-op unless user/session ids, an existing history and a memory
    backend are all present.
    """
    if not (user_id and session_id and conversation_history and self.conversation_memory):
        return

    new_turn = ConversationTurn(
        turn_id=turn_id,
        user_id=user_id,
        user_message=original_prompt,
        assistant_response=assistant_response,
        context_used=system_prompt,
        tools_used=tools_used or [],
        metadata={
            "message_count": len(messages),
            "has_system_prompt": bool(system_prompt),
            "provider": getattr(self, 'client_type', 'unknown')
        }
    )

    # Append the turn to the stored history for this session.
    await self.conversation_memory.add_turn(
        user_id,
        session_id,
        new_turn,
        chatbot_id=self._get_chatbot_key()
    )
|
|
1058
|
+
|
|
1059
|
+
def _extract_json_from_response(self, text: str) -> str:
|
|
1060
|
+
"""Extract JSON from Claude's response, handling markdown code blocks and extra text."""
|
|
1061
|
+
# First, try to find JSON in markdown code blocks
|
|
1062
|
+
json_pattern = r'```(?:json)?\s*(\{.*?\})\s*```'
|
|
1063
|
+
match = re.search(json_pattern, text, re.DOTALL)
|
|
1064
|
+
if match:
|
|
1065
|
+
return match.group(1).strip()
|
|
1066
|
+
|
|
1067
|
+
# Try to find JSON object in the text (looking for { ... })
|
|
1068
|
+
json_object_pattern = r'\{.*\}'
|
|
1069
|
+
match = re.search(json_object_pattern, text, re.DOTALL)
|
|
1070
|
+
if match:
|
|
1071
|
+
return match.group(0).strip()
|
|
1072
|
+
|
|
1073
|
+
# Try to find JSON array in the text (looking for [ ... ])
|
|
1074
|
+
json_array_pattern = r'\[.*\]'
|
|
1075
|
+
match = re.search(json_array_pattern, text, re.DOTALL)
|
|
1076
|
+
if match:
|
|
1077
|
+
return match.group(0).strip()
|
|
1078
|
+
|
|
1079
|
+
# If no JSON found, return the original text
|
|
1080
|
+
return text.strip()
|
|
1081
|
+
|
|
1082
|
+
def _unwrap_nested_response(self, parsed_json: Any, output_type: type) -> Any:
|
|
1083
|
+
"""Unwrap JSON responses that are nested under a single key.
|
|
1084
|
+
|
|
1085
|
+
Some LLMs (especially Claude) wrap their response in an extra key layer.
|
|
1086
|
+
For example: {"dinner_plan": {"appetizer": "...", ...}}
|
|
1087
|
+
instead of: {"appetizer": "...", ...}
|
|
1088
|
+
|
|
1089
|
+
This method detects and unwraps such responses.
|
|
1090
|
+
"""
|
|
1091
|
+
if not isinstance(parsed_json, dict):
|
|
1092
|
+
return parsed_json
|
|
1093
|
+
|
|
1094
|
+
# If the JSON has exactly one key and it's a dict, check if unwrapping makes sense
|
|
1095
|
+
if len(parsed_json) == 1:
|
|
1096
|
+
single_key = list(parsed_json.keys())[0]
|
|
1097
|
+
nested_value = parsed_json[single_key]
|
|
1098
|
+
|
|
1099
|
+
# Only unwrap if the nested value is a dict
|
|
1100
|
+
if isinstance(nested_value, dict):
|
|
1101
|
+
# Try to validate the nested value against the expected type
|
|
1102
|
+
if hasattr(output_type, 'model_validate'):
|
|
1103
|
+
try:
|
|
1104
|
+
# If this succeeds, the nested value is the correct structure
|
|
1105
|
+
output_type.model_validate(nested_value)
|
|
1106
|
+
return nested_value
|
|
1107
|
+
except (ValidationError, Exception):
|
|
1108
|
+
# If validation fails, return original
|
|
1109
|
+
pass
|
|
1110
|
+
elif hasattr(output_type, '__annotations__'):
|
|
1111
|
+
# For dataclasses, check if fields match
|
|
1112
|
+
expected_fields = set(output_type.__annotations__.keys())
|
|
1113
|
+
nested_fields = set(nested_value.keys())
|
|
1114
|
+
|
|
1115
|
+
# If nested value has the expected fields, unwrap it
|
|
1116
|
+
if expected_fields & nested_fields: # If there's any overlap
|
|
1117
|
+
return nested_value
|
|
1118
|
+
|
|
1119
|
+
return parsed_json
|
|
1120
|
+
|
|
1121
|
+
async def _parse_structured_output(  # noqa: C901
    self,
    response_text: str,
    structured_output: StructuredOutputConfig
) -> Any:
    """Parse structured output based on format.

    Dispatches on ``structured_output.format`` (JSON, TEXT, CSV, YAML,
    CUSTOM) and coerces the parsed data into
    ``structured_output.output_type`` where possible.  Any parsing failure
    falls back to returning the raw ``response_text`` so callers always
    receive something usable.
    """
    try:
        output_type = structured_output.output_type
        if not output_type:
            raise ValueError(
                "Output type is not specified in structured output config."
            )
        # default to JSON parsing if no specific schema is provided
        if structured_output.format == OutputFormat.JSON:
            try:
                # Strip a leading ```json fence.  BUG FIX: the old code
                # always sliced off the final three characters, which
                # corrupted responses that had no closing fence — only
                # remove the trailing ``` when it is actually present.
                response_text = response_text.strip()
                if response_text.startswith('```json'):
                    response_text = response_text[7:]
                    if response_text.endswith('```'):
                        response_text = response_text[:-3]
                if hasattr(output_type, 'model_validate_json') or hasattr(output_type, 'model_validate'):
                    # Pydantic path: parse first so nesting can be unwrapped.
                    if not isinstance(output_type, type):
                        output_type = output_type.__class__
                    parsed_json = self._json.loads(response_text)
                    parsed_json = self._unwrap_nested_response(parsed_json, output_type)
                    return output_type.model_validate(parsed_json)
                else:
                    parsed_json = self._json.loads(response_text)
                    parsed_json = self._unwrap_nested_response(parsed_json, output_type)
                    if is_dataclass(output_type) or hasattr(output_type, '__annotations__'):
                        return self._coerce_mapping_to_type(output_type, parsed_json)
                    return parsed_json
            except (ParserError, ValidationError, json.JSONDecodeError) as e:
                self.logger.warning(f"Standard parsing failed: {e}")
                try:
                    # Fallback: extract the JSON payload from surrounding prose.
                    json_text = self._extract_json_from_response(response_text)
                    parsed_json = self._json.loads(json_text)
                    parsed_json = self._unwrap_nested_response(parsed_json, output_type)
                    if hasattr(output_type, 'model_validate'):
                        return output_type.model_validate(parsed_json)
                    if is_dataclass(output_type) or hasattr(output_type, '__annotations__'):
                        return self._coerce_mapping_to_type(output_type, parsed_json)
                    return parsed_json
                except (ParserError, ValidationError, json.JSONDecodeError) as e:
                    self.logger.warning(
                        f"Fallback parsing failed: {e}"
                    )
                    return response_text
        elif structured_output.format == OutputFormat.TEXT:
            # Parse natural language text into structured format
            return await self._parse_text_to_structure(
                response_text,
                output_type
            )
        elif structured_output.format == OutputFormat.CSV:
            # pandas already yields a DataFrame; the old
            # `df if output_type == pd.DataFrame else df` ternary returned
            # the same object on both branches and was removed.
            return pd.read_csv(io.StringIO(response_text))
        elif structured_output.format == OutputFormat.YAML:
            data = yaml.safe_load(response_text)
            if hasattr(output_type, 'model_validate'):
                return output_type.model_validate(data)
            if is_dataclass(output_type) or hasattr(output_type, '__annotations__'):
                return self._coerce_mapping_to_type(output_type, data)
            return data
        elif structured_output.format == OutputFormat.CUSTOM:
            if structured_output.custom_parser:
                return structured_output.custom_parser(response_text)
            # BUG FIX: a CUSTOM format without a parser is a configuration
            # error, not an "unsupported format" as previously reported.
            raise ValueError(
                "Custom output format requires a custom_parser."
            )
        else:
            # BUG FIX: unknown formats previously fell through and
            # returned None silently; raise so the outer handler logs the
            # problem and falls back to the raw text.
            raise ValueError(
                f"Unsupported output format: {structured_output.format}"
            )
    except (ParserError, ValueError) as exc:
        self.logger.error(f"Error parsing structured output: {exc}")
        # Fallback to raw text if parsing fails
        return response_text
    except Exception as exc:
        self.logger.error(
            f"Unexpected error during structured output parsing: {exc}"
        )
        # Fallback to raw text
        return response_text
|
|
1205
|
+
|
|
1206
|
+
async def _parse_text_to_structure(self, text: str, output_type: type) -> Any:
|
|
1207
|
+
"""Parse natural language text into a structured format using AI."""
|
|
1208
|
+
# Option 1: Use regex/NLP parsing for simple cases
|
|
1209
|
+
if hasattr(output_type, '__annotations__'):
|
|
1210
|
+
annotations = output_type.__annotations__
|
|
1211
|
+
|
|
1212
|
+
# Simple extraction for common patterns
|
|
1213
|
+
if 'addition_result' in annotations and 'multiplication_result' in annotations:
|
|
1214
|
+
|
|
1215
|
+
# Extract numbers from text like "12 + 8 = 20" and "6 * 9 = 54"
|
|
1216
|
+
addition_match = re.search(r'(\d+)\s*\+\s*(\d+)\s*=\s*(\d+)', text)
|
|
1217
|
+
multiplication_match = re.search(r'(\d+)\s*\*\s*(\d+)\s*=\s*(\d+)', text)
|
|
1218
|
+
|
|
1219
|
+
data = {
|
|
1220
|
+
'addition_result': float(addition_match.group(3)) if addition_match else 0.0,
|
|
1221
|
+
'multiplication_result': float(
|
|
1222
|
+
multiplication_match.group(3)
|
|
1223
|
+
) if multiplication_match else 0.0,
|
|
1224
|
+
'explanation': text
|
|
1225
|
+
}
|
|
1226
|
+
|
|
1227
|
+
return output_type(**data)
|
|
1228
|
+
|
|
1229
|
+
# Fallback: return text if parsing fails
|
|
1230
|
+
return text
|
|
1231
|
+
|
|
1232
|
+
def _save_image(
|
|
1233
|
+
self,
|
|
1234
|
+
image: Any,
|
|
1235
|
+
output_directory: Path,
|
|
1236
|
+
prefix: str = 'generated_image_'
|
|
1237
|
+
) -> Path:
|
|
1238
|
+
"""Save a PIL image to the specified directory."""
|
|
1239
|
+
output_directory.mkdir(parents=True, exist_ok=True)
|
|
1240
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
1241
|
+
file_path = output_directory / f"{prefix}{timestamp}.jpeg"
|
|
1242
|
+
image.save(file_path)
|
|
1243
|
+
self.logger.info(f"Saved image to {file_path}")
|
|
1244
|
+
return file_path
|
|
1245
|
+
|
|
1246
|
+
def _save_audio_file(self, audio_data: bytes, output_path: Path, mime_format: str):
|
|
1247
|
+
"""
|
|
1248
|
+
Saves the audio data to a file in the specified format.
|
|
1249
|
+
"""
|
|
1250
|
+
from pydub import AudioSegment # pylint: disable=C0415 # noqa
|
|
1251
|
+
import wave # pylint: disable=C0415 # noqa
|
|
1252
|
+
if mime_format == "audio/wav":
|
|
1253
|
+
# Save as WAV using the wave module
|
|
1254
|
+
output_path = output_path.with_suffix('.wav')
|
|
1255
|
+
with wave.open(str(output_path), mode="wb") as wf:
|
|
1256
|
+
# Mono
|
|
1257
|
+
wf.setnchannels(1) # pylint: disable=E1101 # noqa
|
|
1258
|
+
# 16-bit PCM
|
|
1259
|
+
wf.setsampwidth(2) # pylint: disable=E1101 # noqa
|
|
1260
|
+
wf.setcomptype("NONE", "not compressed") # pylint: disable=E1101 # noqa
|
|
1261
|
+
# 24kHz sample rate
|
|
1262
|
+
wf.setframerate(24000) # pylint: disable=E1101 # noqa
|
|
1263
|
+
wf.writeframes(audio_data) # pylint: disable=E1101 # noqa
|
|
1264
|
+
elif mime_format in ("audio/mpeg", "audio/webm"):
|
|
1265
|
+
# choose extension and pydub format name
|
|
1266
|
+
ext = "mp3" if mime_format == "audio/mpeg" else "webm"
|
|
1267
|
+
fp = output_path.with_suffix(f'.{ext}')
|
|
1268
|
+
|
|
1269
|
+
# wrap raw PCM bytes in a BytesIO so pydub can read them
|
|
1270
|
+
raw = io.BytesIO(audio_data)
|
|
1271
|
+
seg = AudioSegment.from_raw(
|
|
1272
|
+
raw,
|
|
1273
|
+
sample_width=2,
|
|
1274
|
+
frame_rate=24000,
|
|
1275
|
+
channels=1
|
|
1276
|
+
)
|
|
1277
|
+
# export using the appropriate container/codec
|
|
1278
|
+
seg.export(str(fp), format=ext)
|
|
1279
|
+
|
|
1280
|
+
else:
|
|
1281
|
+
raise ValueError(f"Unsupported mime_format: {mime_format!r}")
|
|
1282
|
+
|
|
1283
|
+
def _save_video_file(
|
|
1284
|
+
self,
|
|
1285
|
+
mp4_bytes,
|
|
1286
|
+
output_dir: Path,
|
|
1287
|
+
video_number: int = 1,
|
|
1288
|
+
mime_format: str = 'video/mp4',
|
|
1289
|
+
prefix: str = 'generated_video_'
|
|
1290
|
+
) -> Path:
|
|
1291
|
+
"""
|
|
1292
|
+
Download the GenAI video (always MP4), then either:
|
|
1293
|
+
- Write it straight out if mime_format is video/mp4
|
|
1294
|
+
- Otherwise, transcode via ffmpeg to the requested container/codec
|
|
1295
|
+
Returns the Path to the saved file.
|
|
1296
|
+
|
|
1297
|
+
"""
|
|
1298
|
+
import ffmpeg # pylint: disable=C0415 # noqa
|
|
1299
|
+
# 1) Prep output path
|
|
1300
|
+
output_dir.mkdir(parents=True, exist_ok=True)
|
|
1301
|
+
ext = mimetypes.guess_extension(mime_format) or '.mp4'
|
|
1302
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
1303
|
+
out_path = output_dir / f"{prefix}{timestamp}_{video_number}{ext}"
|
|
1304
|
+
|
|
1305
|
+
# 3) Straight-dump for MP4
|
|
1306
|
+
if mime_format == "video/mp4":
|
|
1307
|
+
out_path.write_bytes(mp4_bytes)
|
|
1308
|
+
self.logger.info(
|
|
1309
|
+
f"Saved MP4 to {out_path}"
|
|
1310
|
+
)
|
|
1311
|
+
return out_path
|
|
1312
|
+
|
|
1313
|
+
# 4) Transcode via ffmpeg for other formats
|
|
1314
|
+
try:
|
|
1315
|
+
if mime_format == 'video/avi':
|
|
1316
|
+
video_format = 'avi'
|
|
1317
|
+
vcodec = 'libxvid' # H.264 codec for AVI
|
|
1318
|
+
acodec = 'mp2' # MP2 audio codec for AVI
|
|
1319
|
+
elif mime_format == 'video/webm':
|
|
1320
|
+
video_format = 'webm'
|
|
1321
|
+
vcodec = 'libvpx' # VP8 video codec for WebM
|
|
1322
|
+
acodec = 'libopus'
|
|
1323
|
+
elif mime_format == 'video/mpeg':
|
|
1324
|
+
video_format = 'mpeg'
|
|
1325
|
+
vcodec = 'mpeg2video' # MPEG-2 video codec
|
|
1326
|
+
acodec = 'mp2' # MP2 audio codec
|
|
1327
|
+
else:
|
|
1328
|
+
raise ValueError(
|
|
1329
|
+
f"Unsupported mime_format for video transcoding: {mime_format!r}"
|
|
1330
|
+
)
|
|
1331
|
+
# 1. Set up the FFmpeg process
|
|
1332
|
+
process = (
|
|
1333
|
+
ffmpeg # pylint: disable=E1101 # noqa
|
|
1334
|
+
.input('pipe:', format='mp4') # pylint: disable=E1101 # noqa
|
|
1335
|
+
.output(
|
|
1336
|
+
'pipe:',
|
|
1337
|
+
format=video_format, # Output container format
|
|
1338
|
+
vcodec=vcodec, # video codec
|
|
1339
|
+
acodec=acodec # audio codec
|
|
1340
|
+
)
|
|
1341
|
+
.run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
|
|
1342
|
+
)
|
|
1343
|
+
# 2. Pipe the mp4 bytes in and get the webm bytes out
|
|
1344
|
+
out_bytes, err = process.communicate(input=mp4_bytes)
|
|
1345
|
+
process.wait()
|
|
1346
|
+
if err:
|
|
1347
|
+
self.logger.error("FFmpeg Error:", err.decode())
|
|
1348
|
+
with open(out_path, 'wb') as f:
|
|
1349
|
+
f.write(out_bytes)
|
|
1350
|
+
self.logger.info(
|
|
1351
|
+
f"Saved {mime_format} to {out_path}"
|
|
1352
|
+
)
|
|
1353
|
+
return out_path
|
|
1354
|
+
except Exception as e:
|
|
1355
|
+
self.logger.error(
|
|
1356
|
+
f"Error saving {mime_format} to {out_path}: {e}"
|
|
1357
|
+
)
|
|
1358
|
+
return None
|
|
1359
|
+
|
|
1360
|
+
@staticmethod
def create_conversation_memory(
    memory_type: str = "memory",
    **kwargs
) -> ConversationMemory:
    """Factory method to create a conversation memory instance."""
    # Lazily-evaluated factories: a backend class is only resolved and
    # constructed for the backend actually requested.
    factories = {
        "memory": lambda: InMemoryConversation(),
        "redis": lambda: RedisConversation(**kwargs),
        "file": lambda: FileConversationMemory(**kwargs),
    }
    try:
        factory = factories[memory_type]
    except KeyError:
        raise ValueError(
            f"Unsupported memory type: {memory_type}"
        ) from None
    return factory()
|
|
1376
|
+
|
|
1377
|
+
async def _wait_with_backoff(self, retry_count: int, config: StreamingRetryConfig) -> None:
    """Sleep before the next retry attempt using capped exponential backoff."""
    # Exponential growth, capped at config.max_delay.
    growth = config.backoff_factor ** (retry_count - 1)
    delay = min(config.base_delay * growth, config.max_delay)

    if config.jitter:
        # Randomize within [0.5, 1.0) of the delay to avoid thundering herd.
        delay *= 0.5 + random.random() * 0.5

    await asyncio.sleep(delay)
|
|
1389
|
+
|
|
1390
|
+
def _parse_json_from_text(self, text: str) -> Union[dict, list]:
    """Robustly parse JSON even if the model wraps it in ```json fences."""
    if not text:
        return {}

    # Trim whitespace, then peel off leading/trailing code fences.
    candidate = text.strip()
    candidate = re.sub(r"^```(?:json)?\s*", "", candidate, flags=re.I)
    candidate = re.sub(r"\s*```$", "", candidate)

    # If prose surrounds the payload, isolate the largest {...} or [...] span.
    block = re.search(r"(\{.*\}|\[.*\])", candidate, flags=re.S)
    if block:
        candidate = block[1]

    return json_decoder(candidate)
|
|
1402
|
+
|
|
1403
|
+
def _oai_normalize_schema(self, schema: dict, *, force_required_all: bool = True) -> dict:
|
|
1404
|
+
"""
|
|
1405
|
+
Normalize JSON schema.
|
|
1406
|
+
- Always sets additionalProperties=false on objects.
|
|
1407
|
+
- Optionally forces required to include all properties.
|
|
1408
|
+
"""
|
|
1409
|
+
def visit(node):
|
|
1410
|
+
if isinstance(node, dict):
|
|
1411
|
+
t = node.get("type")
|
|
1412
|
+
|
|
1413
|
+
if t == "object":
|
|
1414
|
+
node["additionalProperties"] = False
|
|
1415
|
+
|
|
1416
|
+
if force_required_all:
|
|
1417
|
+
props = node.get("properties")
|
|
1418
|
+
if isinstance(props, dict) and props:
|
|
1419
|
+
prop_keys = list(props.keys())
|
|
1420
|
+
existing_required = node.get("required") or []
|
|
1421
|
+
missing = [k for k in prop_keys if k not in existing_required]
|
|
1422
|
+
node["required"] = existing_required + missing
|
|
1423
|
+
|
|
1424
|
+
for key in ("properties", "patternProperties"):
|
|
1425
|
+
if isinstance(node.get(key), dict):
|
|
1426
|
+
for sub in node[key].values():
|
|
1427
|
+
visit(sub)
|
|
1428
|
+
|
|
1429
|
+
if t == "array" and isinstance(node.get("items"), (dict, list)):
|
|
1430
|
+
visit(node["items"])
|
|
1431
|
+
|
|
1432
|
+
for key in ("anyOf", "allOf", "oneOf"):
|
|
1433
|
+
if isinstance(node.get(key), list):
|
|
1434
|
+
for sub in node[key]:
|
|
1435
|
+
visit(sub)
|
|
1436
|
+
|
|
1437
|
+
for key in ("$defs", "definitions"):
|
|
1438
|
+
if isinstance(node.get(key), dict):
|
|
1439
|
+
for sub in node[key].values():
|
|
1440
|
+
visit(sub)
|
|
1441
|
+
|
|
1442
|
+
elif isinstance(node, list):
|
|
1443
|
+
for item in node:
|
|
1444
|
+
visit(item)
|
|
1445
|
+
|
|
1446
|
+
return node
|
|
1447
|
+
|
|
1448
|
+
return visit(dict(schema))
|
|
1449
|
+
|
|
1450
|
+
def _build_response_format_from(self, output_config):
|
|
1451
|
+
"""
|
|
1452
|
+
Build a valid OpenAI response_format payload from a StructuredOutputConfig
|
|
1453
|
+
or a direct Pydantic/dataclass type. Ensures additionalProperties:false.
|
|
1454
|
+
"""
|
|
1455
|
+
if not output_config:
|
|
1456
|
+
return None
|
|
1457
|
+
|
|
1458
|
+
# Explicit JSON-only request (no schema)
|
|
1459
|
+
fmt = getattr(output_config, "format", None)
|
|
1460
|
+
if fmt and str(fmt).lower().endswith("json_object"):
|
|
1461
|
+
return {"type": "json_object"}
|
|
1462
|
+
|
|
1463
|
+
ot = getattr(output_config, "output_type", None) or output_config
|
|
1464
|
+
|
|
1465
|
+
# Pydantic model -> JSON Schema
|
|
1466
|
+
if isinstance(ot, type) and issubclass(ot, BaseModel):
|
|
1467
|
+
raw = ot.model_json_schema()
|
|
1468
|
+
schema = self._oai_normalize_schema(raw)
|
|
1469
|
+
return {
|
|
1470
|
+
"type": "json_schema",
|
|
1471
|
+
"json_schema": {
|
|
1472
|
+
"name": getattr(output_config, "name", None) or ot.__name__,
|
|
1473
|
+
"schema": schema,
|
|
1474
|
+
"strict": True,
|
|
1475
|
+
},
|
|
1476
|
+
}
|
|
1477
|
+
# Python dataclass -> JSON Schema
|
|
1478
|
+
if is_dataclass(ot):
|
|
1479
|
+
ta = TypeAdapter(ot)
|
|
1480
|
+
raw = ta.json_schema()
|
|
1481
|
+
schema = self._oai_normalize_schema(raw)
|
|
1482
|
+
return {
|
|
1483
|
+
"type": "json_schema",
|
|
1484
|
+
"json_schema": {
|
|
1485
|
+
"name": getattr(output_config, "name", None) or ot.__name__,
|
|
1486
|
+
"schema": schema,
|
|
1487
|
+
"strict": True,
|
|
1488
|
+
},
|
|
1489
|
+
}
|
|
1490
|
+
# Fallback: at least constrain to JSON object
|
|
1491
|
+
return {"type": "json_object"}
|