ai-parrot 0.17.2__cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentui/.prettierrc +15 -0
- agentui/QUICKSTART.md +272 -0
- agentui/README.md +59 -0
- agentui/env.example +16 -0
- agentui/jsconfig.json +14 -0
- agentui/package-lock.json +4242 -0
- agentui/package.json +34 -0
- agentui/scripts/postinstall/apply-patches.mjs +260 -0
- agentui/src/app.css +61 -0
- agentui/src/app.d.ts +13 -0
- agentui/src/app.html +12 -0
- agentui/src/components/LoadingSpinner.svelte +64 -0
- agentui/src/components/ThemeSwitcher.svelte +159 -0
- agentui/src/components/index.js +4 -0
- agentui/src/lib/api/bots.ts +60 -0
- agentui/src/lib/api/chat.ts +22 -0
- agentui/src/lib/api/http.ts +25 -0
- agentui/src/lib/components/BotCard.svelte +33 -0
- agentui/src/lib/components/ChatBubble.svelte +63 -0
- agentui/src/lib/components/Toast.svelte +21 -0
- agentui/src/lib/config.ts +20 -0
- agentui/src/lib/stores/auth.svelte.ts +73 -0
- agentui/src/lib/stores/theme.svelte.js +64 -0
- agentui/src/lib/stores/toast.svelte.ts +31 -0
- agentui/src/lib/utils/conversation.ts +39 -0
- agentui/src/routes/+layout.svelte +20 -0
- agentui/src/routes/+page.svelte +232 -0
- agentui/src/routes/login/+page.svelte +200 -0
- agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
- agentui/src/routes/talk/[agentId]/+page.ts +7 -0
- agentui/static/README.md +1 -0
- agentui/svelte.config.js +11 -0
- agentui/tailwind.config.ts +53 -0
- agentui/tsconfig.json +3 -0
- agentui/vite.config.ts +10 -0
- ai_parrot-0.17.2.dist-info/METADATA +472 -0
- ai_parrot-0.17.2.dist-info/RECORD +535 -0
- ai_parrot-0.17.2.dist-info/WHEEL +6 -0
- ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
- ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
- ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
- crew-builder/.prettierrc +15 -0
- crew-builder/QUICKSTART.md +259 -0
- crew-builder/README.md +113 -0
- crew-builder/env.example +17 -0
- crew-builder/jsconfig.json +14 -0
- crew-builder/package-lock.json +4182 -0
- crew-builder/package.json +37 -0
- crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
- crew-builder/src/app.css +62 -0
- crew-builder/src/app.d.ts +13 -0
- crew-builder/src/app.html +12 -0
- crew-builder/src/components/LoadingSpinner.svelte +64 -0
- crew-builder/src/components/ThemeSwitcher.svelte +149 -0
- crew-builder/src/components/index.js +9 -0
- crew-builder/src/lib/api/bots.ts +60 -0
- crew-builder/src/lib/api/chat.ts +80 -0
- crew-builder/src/lib/api/client.ts +56 -0
- crew-builder/src/lib/api/crew/crew.ts +136 -0
- crew-builder/src/lib/api/index.ts +5 -0
- crew-builder/src/lib/api/o365/auth.ts +65 -0
- crew-builder/src/lib/auth/auth.ts +54 -0
- crew-builder/src/lib/components/AgentNode.svelte +43 -0
- crew-builder/src/lib/components/BotCard.svelte +33 -0
- crew-builder/src/lib/components/ChatBubble.svelte +67 -0
- crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
- crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
- crew-builder/src/lib/components/JsonViewer.svelte +24 -0
- crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
- crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
- crew-builder/src/lib/components/Toast.svelte +67 -0
- crew-builder/src/lib/components/Toolbar.svelte +157 -0
- crew-builder/src/lib/components/index.ts +10 -0
- crew-builder/src/lib/config.ts +8 -0
- crew-builder/src/lib/stores/auth.svelte.ts +228 -0
- crew-builder/src/lib/stores/crewStore.ts +369 -0
- crew-builder/src/lib/stores/theme.svelte.js +145 -0
- crew-builder/src/lib/stores/toast.svelte.ts +69 -0
- crew-builder/src/lib/utils/conversation.ts +39 -0
- crew-builder/src/lib/utils/markdown.ts +122 -0
- crew-builder/src/lib/utils/talkHistory.ts +47 -0
- crew-builder/src/routes/+layout.svelte +20 -0
- crew-builder/src/routes/+page.svelte +539 -0
- crew-builder/src/routes/agents/+page.svelte +247 -0
- crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
- crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
- crew-builder/src/routes/builder/+page.svelte +204 -0
- crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
- crew-builder/src/routes/crew/ask/+page.ts +1 -0
- crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
- crew-builder/src/routes/login/+page.svelte +197 -0
- crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
- crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
- crew-builder/static/README.md +1 -0
- crew-builder/svelte.config.js +11 -0
- crew-builder/tailwind.config.ts +53 -0
- crew-builder/tsconfig.json +3 -0
- crew-builder/vite.config.ts +10 -0
- mcp_servers/calculator_server.py +309 -0
- parrot/__init__.py +27 -0
- parrot/__pycache__/__init__.cpython-310.pyc +0 -0
- parrot/__pycache__/version.cpython-310.pyc +0 -0
- parrot/_version.py +34 -0
- parrot/a2a/__init__.py +48 -0
- parrot/a2a/client.py +658 -0
- parrot/a2a/discovery.py +89 -0
- parrot/a2a/mixin.py +257 -0
- parrot/a2a/models.py +376 -0
- parrot/a2a/server.py +770 -0
- parrot/agents/__init__.py +29 -0
- parrot/bots/__init__.py +12 -0
- parrot/bots/a2a_agent.py +19 -0
- parrot/bots/abstract.py +3139 -0
- parrot/bots/agent.py +1129 -0
- parrot/bots/basic.py +9 -0
- parrot/bots/chatbot.py +669 -0
- parrot/bots/data.py +1618 -0
- parrot/bots/database/__init__.py +5 -0
- parrot/bots/database/abstract.py +3071 -0
- parrot/bots/database/cache.py +286 -0
- parrot/bots/database/models.py +468 -0
- parrot/bots/database/prompts.py +154 -0
- parrot/bots/database/retries.py +98 -0
- parrot/bots/database/router.py +269 -0
- parrot/bots/database/sql.py +41 -0
- parrot/bots/db/__init__.py +6 -0
- parrot/bots/db/abstract.py +556 -0
- parrot/bots/db/bigquery.py +602 -0
- parrot/bots/db/cache.py +85 -0
- parrot/bots/db/documentdb.py +668 -0
- parrot/bots/db/elastic.py +1014 -0
- parrot/bots/db/influx.py +898 -0
- parrot/bots/db/mock.py +96 -0
- parrot/bots/db/multi.py +783 -0
- parrot/bots/db/prompts.py +185 -0
- parrot/bots/db/sql.py +1255 -0
- parrot/bots/db/tools.py +212 -0
- parrot/bots/document.py +680 -0
- parrot/bots/hrbot.py +15 -0
- parrot/bots/kb.py +170 -0
- parrot/bots/mcp.py +36 -0
- parrot/bots/orchestration/README.md +463 -0
- parrot/bots/orchestration/__init__.py +1 -0
- parrot/bots/orchestration/agent.py +155 -0
- parrot/bots/orchestration/crew.py +3330 -0
- parrot/bots/orchestration/fsm.py +1179 -0
- parrot/bots/orchestration/hr.py +434 -0
- parrot/bots/orchestration/storage/__init__.py +4 -0
- parrot/bots/orchestration/storage/memory.py +100 -0
- parrot/bots/orchestration/storage/mixin.py +119 -0
- parrot/bots/orchestration/verify.py +202 -0
- parrot/bots/product.py +204 -0
- parrot/bots/prompts/__init__.py +96 -0
- parrot/bots/prompts/agents.py +155 -0
- parrot/bots/prompts/data.py +216 -0
- parrot/bots/prompts/output_generation.py +8 -0
- parrot/bots/scraper/__init__.py +3 -0
- parrot/bots/scraper/models.py +122 -0
- parrot/bots/scraper/scraper.py +1173 -0
- parrot/bots/scraper/templates.py +115 -0
- parrot/bots/stores/__init__.py +5 -0
- parrot/bots/stores/local.py +172 -0
- parrot/bots/webdev.py +81 -0
- parrot/cli.py +17 -0
- parrot/clients/__init__.py +16 -0
- parrot/clients/base.py +1491 -0
- parrot/clients/claude.py +1191 -0
- parrot/clients/factory.py +129 -0
- parrot/clients/google.py +4567 -0
- parrot/clients/gpt.py +1975 -0
- parrot/clients/grok.py +432 -0
- parrot/clients/groq.py +986 -0
- parrot/clients/hf.py +582 -0
- parrot/clients/models.py +18 -0
- parrot/conf.py +395 -0
- parrot/embeddings/__init__.py +9 -0
- parrot/embeddings/base.py +157 -0
- parrot/embeddings/google.py +98 -0
- parrot/embeddings/huggingface.py +74 -0
- parrot/embeddings/openai.py +84 -0
- parrot/embeddings/processor.py +88 -0
- parrot/exceptions.c +13868 -0
- parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/exceptions.pxd +22 -0
- parrot/exceptions.pxi +15 -0
- parrot/exceptions.pyx +44 -0
- parrot/generators/__init__.py +29 -0
- parrot/generators/base.py +200 -0
- parrot/generators/html.py +293 -0
- parrot/generators/react.py +205 -0
- parrot/generators/streamlit.py +203 -0
- parrot/generators/template.py +105 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/agent.py +861 -0
- parrot/handlers/agents/__init__.py +1 -0
- parrot/handlers/agents/abstract.py +900 -0
- parrot/handlers/bots.py +338 -0
- parrot/handlers/chat.py +915 -0
- parrot/handlers/creation.sql +192 -0
- parrot/handlers/crew/ARCHITECTURE.md +362 -0
- parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
- parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
- parrot/handlers/crew/__init__.py +0 -0
- parrot/handlers/crew/handler.py +801 -0
- parrot/handlers/crew/models.py +229 -0
- parrot/handlers/crew/redis_persistence.py +523 -0
- parrot/handlers/jobs/__init__.py +10 -0
- parrot/handlers/jobs/job.py +384 -0
- parrot/handlers/jobs/mixin.py +627 -0
- parrot/handlers/jobs/models.py +115 -0
- parrot/handlers/jobs/worker.py +31 -0
- parrot/handlers/models.py +596 -0
- parrot/handlers/o365_auth.py +105 -0
- parrot/handlers/stream.py +337 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/aws.py +143 -0
- parrot/interfaces/credentials.py +113 -0
- parrot/interfaces/database.py +27 -0
- parrot/interfaces/google.py +1123 -0
- parrot/interfaces/hierarchy.py +1227 -0
- parrot/interfaces/http.py +651 -0
- parrot/interfaces/images/__init__.py +0 -0
- parrot/interfaces/images/plugins/__init__.py +24 -0
- parrot/interfaces/images/plugins/abstract.py +58 -0
- parrot/interfaces/images/plugins/analisys.py +148 -0
- parrot/interfaces/images/plugins/classify.py +150 -0
- parrot/interfaces/images/plugins/classifybase.py +182 -0
- parrot/interfaces/images/plugins/detect.py +150 -0
- parrot/interfaces/images/plugins/exif.py +1103 -0
- parrot/interfaces/images/plugins/hash.py +52 -0
- parrot/interfaces/images/plugins/vision.py +104 -0
- parrot/interfaces/images/plugins/yolo.py +66 -0
- parrot/interfaces/images/plugins/zerodetect.py +197 -0
- parrot/interfaces/o365.py +978 -0
- parrot/interfaces/onedrive.py +822 -0
- parrot/interfaces/sharepoint.py +1435 -0
- parrot/interfaces/soap.py +257 -0
- parrot/loaders/__init__.py +8 -0
- parrot/loaders/abstract.py +1131 -0
- parrot/loaders/audio.py +199 -0
- parrot/loaders/basepdf.py +53 -0
- parrot/loaders/basevideo.py +1568 -0
- parrot/loaders/csv.py +409 -0
- parrot/loaders/docx.py +116 -0
- parrot/loaders/epubloader.py +316 -0
- parrot/loaders/excel.py +199 -0
- parrot/loaders/factory.py +55 -0
- parrot/loaders/files/__init__.py +0 -0
- parrot/loaders/files/abstract.py +39 -0
- parrot/loaders/files/html.py +26 -0
- parrot/loaders/files/text.py +63 -0
- parrot/loaders/html.py +152 -0
- parrot/loaders/markdown.py +442 -0
- parrot/loaders/pdf.py +373 -0
- parrot/loaders/pdfmark.py +320 -0
- parrot/loaders/pdftables.py +506 -0
- parrot/loaders/ppt.py +476 -0
- parrot/loaders/qa.py +63 -0
- parrot/loaders/splitters/__init__.py +10 -0
- parrot/loaders/splitters/base.py +138 -0
- parrot/loaders/splitters/md.py +228 -0
- parrot/loaders/splitters/token.py +143 -0
- parrot/loaders/txt.py +26 -0
- parrot/loaders/video.py +89 -0
- parrot/loaders/videolocal.py +218 -0
- parrot/loaders/videounderstanding.py +377 -0
- parrot/loaders/vimeo.py +167 -0
- parrot/loaders/web.py +599 -0
- parrot/loaders/youtube.py +504 -0
- parrot/manager/__init__.py +5 -0
- parrot/manager/manager.py +1030 -0
- parrot/mcp/__init__.py +28 -0
- parrot/mcp/adapter.py +105 -0
- parrot/mcp/cli.py +174 -0
- parrot/mcp/client.py +119 -0
- parrot/mcp/config.py +75 -0
- parrot/mcp/integration.py +842 -0
- parrot/mcp/oauth.py +933 -0
- parrot/mcp/server.py +225 -0
- parrot/mcp/transports/__init__.py +3 -0
- parrot/mcp/transports/base.py +279 -0
- parrot/mcp/transports/grpc_session.py +163 -0
- parrot/mcp/transports/http.py +312 -0
- parrot/mcp/transports/mcp.proto +108 -0
- parrot/mcp/transports/quic.py +1082 -0
- parrot/mcp/transports/sse.py +330 -0
- parrot/mcp/transports/stdio.py +309 -0
- parrot/mcp/transports/unix.py +395 -0
- parrot/mcp/transports/websocket.py +547 -0
- parrot/memory/__init__.py +16 -0
- parrot/memory/abstract.py +209 -0
- parrot/memory/agent.py +32 -0
- parrot/memory/cache.py +175 -0
- parrot/memory/core.py +555 -0
- parrot/memory/file.py +153 -0
- parrot/memory/mem.py +131 -0
- parrot/memory/redis.py +613 -0
- parrot/models/__init__.py +46 -0
- parrot/models/basic.py +118 -0
- parrot/models/compliance.py +208 -0
- parrot/models/crew.py +395 -0
- parrot/models/detections.py +654 -0
- parrot/models/generation.py +85 -0
- parrot/models/google.py +223 -0
- parrot/models/groq.py +23 -0
- parrot/models/openai.py +30 -0
- parrot/models/outputs.py +285 -0
- parrot/models/responses.py +938 -0
- parrot/notifications/__init__.py +743 -0
- parrot/openapi/__init__.py +3 -0
- parrot/openapi/components.yaml +641 -0
- parrot/openapi/config.py +322 -0
- parrot/outputs/__init__.py +32 -0
- parrot/outputs/formats/__init__.py +108 -0
- parrot/outputs/formats/altair.py +359 -0
- parrot/outputs/formats/application.py +122 -0
- parrot/outputs/formats/base.py +351 -0
- parrot/outputs/formats/bokeh.py +356 -0
- parrot/outputs/formats/card.py +424 -0
- parrot/outputs/formats/chart.py +436 -0
- parrot/outputs/formats/d3.py +255 -0
- parrot/outputs/formats/echarts.py +310 -0
- parrot/outputs/formats/generators/__init__.py +0 -0
- parrot/outputs/formats/generators/abstract.py +61 -0
- parrot/outputs/formats/generators/panel.py +145 -0
- parrot/outputs/formats/generators/streamlit.py +86 -0
- parrot/outputs/formats/generators/terminal.py +63 -0
- parrot/outputs/formats/holoviews.py +310 -0
- parrot/outputs/formats/html.py +147 -0
- parrot/outputs/formats/jinja2.py +46 -0
- parrot/outputs/formats/json.py +87 -0
- parrot/outputs/formats/map.py +933 -0
- parrot/outputs/formats/markdown.py +172 -0
- parrot/outputs/formats/matplotlib.py +237 -0
- parrot/outputs/formats/mixins/__init__.py +0 -0
- parrot/outputs/formats/mixins/emaps.py +855 -0
- parrot/outputs/formats/plotly.py +341 -0
- parrot/outputs/formats/seaborn.py +310 -0
- parrot/outputs/formats/table.py +397 -0
- parrot/outputs/formats/template_report.py +138 -0
- parrot/outputs/formats/yaml.py +125 -0
- parrot/outputs/formatter.py +152 -0
- parrot/outputs/templates/__init__.py +95 -0
- parrot/pipelines/__init__.py +0 -0
- parrot/pipelines/abstract.py +210 -0
- parrot/pipelines/detector.py +124 -0
- parrot/pipelines/models.py +90 -0
- parrot/pipelines/planogram.py +3002 -0
- parrot/pipelines/table.sql +97 -0
- parrot/plugins/__init__.py +106 -0
- parrot/plugins/importer.py +80 -0
- parrot/py.typed +0 -0
- parrot/registry/__init__.py +18 -0
- parrot/registry/registry.py +594 -0
- parrot/scheduler/__init__.py +1189 -0
- parrot/scheduler/models.py +60 -0
- parrot/security/__init__.py +16 -0
- parrot/security/prompt_injection.py +268 -0
- parrot/security/security_events.sql +25 -0
- parrot/services/__init__.py +1 -0
- parrot/services/mcp/__init__.py +8 -0
- parrot/services/mcp/config.py +13 -0
- parrot/services/mcp/server.py +295 -0
- parrot/services/o365_remote_auth.py +235 -0
- parrot/stores/__init__.py +7 -0
- parrot/stores/abstract.py +352 -0
- parrot/stores/arango.py +1090 -0
- parrot/stores/bigquery.py +1377 -0
- parrot/stores/cache.py +106 -0
- parrot/stores/empty.py +10 -0
- parrot/stores/faiss_store.py +1157 -0
- parrot/stores/kb/__init__.py +9 -0
- parrot/stores/kb/abstract.py +68 -0
- parrot/stores/kb/cache.py +165 -0
- parrot/stores/kb/doc.py +325 -0
- parrot/stores/kb/hierarchy.py +346 -0
- parrot/stores/kb/local.py +457 -0
- parrot/stores/kb/prompt.py +28 -0
- parrot/stores/kb/redis.py +659 -0
- parrot/stores/kb/store.py +115 -0
- parrot/stores/kb/user.py +374 -0
- parrot/stores/models.py +59 -0
- parrot/stores/pgvector.py +3 -0
- parrot/stores/postgres.py +2853 -0
- parrot/stores/utils/__init__.py +0 -0
- parrot/stores/utils/chunking.py +197 -0
- parrot/telemetry/__init__.py +3 -0
- parrot/telemetry/mixin.py +111 -0
- parrot/template/__init__.py +3 -0
- parrot/template/engine.py +259 -0
- parrot/tools/__init__.py +23 -0
- parrot/tools/abstract.py +644 -0
- parrot/tools/agent.py +363 -0
- parrot/tools/arangodbsearch.py +537 -0
- parrot/tools/arxiv_tool.py +188 -0
- parrot/tools/calculator/__init__.py +3 -0
- parrot/tools/calculator/operations/__init__.py +38 -0
- parrot/tools/calculator/operations/calculus.py +80 -0
- parrot/tools/calculator/operations/statistics.py +76 -0
- parrot/tools/calculator/tool.py +150 -0
- parrot/tools/cloudwatch.py +988 -0
- parrot/tools/codeinterpreter/__init__.py +127 -0
- parrot/tools/codeinterpreter/executor.py +371 -0
- parrot/tools/codeinterpreter/internals.py +473 -0
- parrot/tools/codeinterpreter/models.py +643 -0
- parrot/tools/codeinterpreter/prompts.py +224 -0
- parrot/tools/codeinterpreter/tool.py +664 -0
- parrot/tools/company_info/__init__.py +6 -0
- parrot/tools/company_info/tool.py +1138 -0
- parrot/tools/correlationanalysis.py +437 -0
- parrot/tools/database/abstract.py +286 -0
- parrot/tools/database/bq.py +115 -0
- parrot/tools/database/cache.py +284 -0
- parrot/tools/database/models.py +95 -0
- parrot/tools/database/pg.py +343 -0
- parrot/tools/databasequery.py +1159 -0
- parrot/tools/db.py +1800 -0
- parrot/tools/ddgo.py +370 -0
- parrot/tools/decorators.py +271 -0
- parrot/tools/dftohtml.py +282 -0
- parrot/tools/document.py +549 -0
- parrot/tools/ecs.py +819 -0
- parrot/tools/edareport.py +368 -0
- parrot/tools/elasticsearch.py +1049 -0
- parrot/tools/employees.py +462 -0
- parrot/tools/epson/__init__.py +96 -0
- parrot/tools/excel.py +683 -0
- parrot/tools/file/__init__.py +13 -0
- parrot/tools/file/abstract.py +76 -0
- parrot/tools/file/gcs.py +378 -0
- parrot/tools/file/local.py +284 -0
- parrot/tools/file/s3.py +511 -0
- parrot/tools/file/tmp.py +309 -0
- parrot/tools/file/tool.py +501 -0
- parrot/tools/file_reader.py +129 -0
- parrot/tools/flowtask/__init__.py +19 -0
- parrot/tools/flowtask/tool.py +761 -0
- parrot/tools/gittoolkit.py +508 -0
- parrot/tools/google/__init__.py +18 -0
- parrot/tools/google/base.py +169 -0
- parrot/tools/google/tools.py +1251 -0
- parrot/tools/googlelocation.py +5 -0
- parrot/tools/googleroutes.py +5 -0
- parrot/tools/googlesearch.py +5 -0
- parrot/tools/googlesitesearch.py +5 -0
- parrot/tools/googlevoice.py +2 -0
- parrot/tools/gvoice.py +695 -0
- parrot/tools/ibisworld/README.md +225 -0
- parrot/tools/ibisworld/__init__.py +11 -0
- parrot/tools/ibisworld/tool.py +366 -0
- parrot/tools/jiratoolkit.py +1718 -0
- parrot/tools/manager.py +1098 -0
- parrot/tools/math.py +152 -0
- parrot/tools/metadata.py +476 -0
- parrot/tools/msteams.py +1621 -0
- parrot/tools/msword.py +635 -0
- parrot/tools/multidb.py +580 -0
- parrot/tools/multistoresearch.py +369 -0
- parrot/tools/networkninja.py +167 -0
- parrot/tools/nextstop/__init__.py +4 -0
- parrot/tools/nextstop/base.py +286 -0
- parrot/tools/nextstop/employee.py +733 -0
- parrot/tools/nextstop/store.py +462 -0
- parrot/tools/notification.py +435 -0
- parrot/tools/o365/__init__.py +42 -0
- parrot/tools/o365/base.py +295 -0
- parrot/tools/o365/bundle.py +522 -0
- parrot/tools/o365/events.py +554 -0
- parrot/tools/o365/mail.py +992 -0
- parrot/tools/o365/onedrive.py +497 -0
- parrot/tools/o365/sharepoint.py +641 -0
- parrot/tools/openapi_toolkit.py +904 -0
- parrot/tools/openweather.py +527 -0
- parrot/tools/pdfprint.py +1001 -0
- parrot/tools/powerbi.py +518 -0
- parrot/tools/powerpoint.py +1113 -0
- parrot/tools/pricestool.py +146 -0
- parrot/tools/products/__init__.py +246 -0
- parrot/tools/prophet_tool.py +171 -0
- parrot/tools/pythonpandas.py +630 -0
- parrot/tools/pythonrepl.py +910 -0
- parrot/tools/qsource.py +436 -0
- parrot/tools/querytoolkit.py +395 -0
- parrot/tools/quickeda.py +827 -0
- parrot/tools/resttool.py +553 -0
- parrot/tools/retail/__init__.py +0 -0
- parrot/tools/retail/bby.py +528 -0
- parrot/tools/sandboxtool.py +703 -0
- parrot/tools/sassie/__init__.py +352 -0
- parrot/tools/scraping/__init__.py +7 -0
- parrot/tools/scraping/docs/select.md +466 -0
- parrot/tools/scraping/documentation.md +1278 -0
- parrot/tools/scraping/driver.py +436 -0
- parrot/tools/scraping/models.py +576 -0
- parrot/tools/scraping/options.py +85 -0
- parrot/tools/scraping/orchestrator.py +517 -0
- parrot/tools/scraping/readme.md +740 -0
- parrot/tools/scraping/tool.py +3115 -0
- parrot/tools/seasonaldetection.py +642 -0
- parrot/tools/shell_tool/__init__.py +5 -0
- parrot/tools/shell_tool/actions.py +408 -0
- parrot/tools/shell_tool/engine.py +155 -0
- parrot/tools/shell_tool/models.py +322 -0
- parrot/tools/shell_tool/tool.py +442 -0
- parrot/tools/site_search.py +214 -0
- parrot/tools/textfile.py +418 -0
- parrot/tools/think.py +378 -0
- parrot/tools/toolkit.py +298 -0
- parrot/tools/webapp_tool.py +187 -0
- parrot/tools/whatif.py +1279 -0
- parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
- parrot/tools/workday/__init__.py +6 -0
- parrot/tools/workday/models.py +1389 -0
- parrot/tools/workday/tool.py +1293 -0
- parrot/tools/yfinance_tool.py +306 -0
- parrot/tools/zipcode.py +217 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/helpers.py +73 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.c +12078 -0
- parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/parsers/toml.pyx +21 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpp +20936 -0
- parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/types.pyx +213 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- parrot/yaml-rs/Cargo.lock +350 -0
- parrot/yaml-rs/Cargo.toml +19 -0
- parrot/yaml-rs/pyproject.toml +19 -0
- parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
- parrot/yaml-rs/src/lib.rs +222 -0
- requirements/docker-compose.yml +24 -0
- requirements/requirements-dev.txt +21 -0
@@ -0,0 +1,1568 @@
from __future__ import annotations
from typing import Any, Union, List, Optional, TYPE_CHECKING
from collections.abc import Callable
from abc import abstractmethod
import gc
import os
import logging
import math
from pathlib import Path
import numpy as np
from ..conf import HUGGINGFACEHUB_API_TOKEN
from ..stores.models import Document
from .abstract import AbstractLoader

if TYPE_CHECKING:
    from moviepy import VideoFileClip
    from pydub import AudioSegment
    import whisperx
    import torch
    from transformers import (
        pipeline,
        AutoModelForSeq2SeqLM,
        AutoTokenizer,
        WhisperProcessor,
        WhisperForConditionalGeneration
    )


logging.getLogger(name='numba').setLevel(logging.WARNING)
logging.getLogger(name='pydub.converter').setLevel(logging.WARNING)

def extract_video_id(url):
    parts = url.split("?v=")
    video_id = parts[1].split("&")[0]
    return video_id

def _fmt_srt_time(t: float) -> str:
    hrs, rem = divmod(int(t), 3600)
    mins, secs = divmod(rem, 60)
    ms = int((t - int(t)) * 1000)
    return f"{hrs:02}:{mins:02}:{secs:02},{ms:03}"


class BaseVideoLoader(AbstractLoader):
    """
    Generating Video transcripts from Videos.
    """
    extensions: List[str] = ['.youtube']
    encoding = 'utf-8'

    def __init__(
        self,
        source: Optional[Union[str, Path, List[Union[str, Path]]]] = None,
        tokenizer: Callable[..., Any] = None,
        text_splitter: Callable[..., Any] = None,
        source_type: str = 'video',
        language: str = "en",
        video_path: Union[str, Path] = None,
        download_video: bool = True,
        diarization: bool = False,
        **kwargs
    ):
        self._download_video: bool = download_video
        self._diarization: bool = diarization
        super().__init__(
            source,
            tokenizer=tokenizer,
            text_splitter=text_splitter,
            source_type=source_type,
            **kwargs
        )
        if isinstance(source, str):
            self.urls = [source]
        else:
            self.urls = source
        self._task = kwargs.get('task', "automatic-speech-recognition")
        # Topics:
        self.topics: list = kwargs.get('topics', [])
        self._model_size: str = kwargs.get('model_size', 'small')
        self.summarization_model = "facebook/bart-large-cnn"
        self._model_name: str = kwargs.get('model_name', 'whisper')
        self._use_summary_pipeline: bool = kwargs.get('use_summary_pipeline', False)

        # Lazy loading: Don't load summarizer until needed
        # This saves ~1.6GB of VRAM when summarization is disabled
        self._summarizer = None
        self._summarizer_device = None
        self._summarizer_dtype = None

        # Store device info for lazy loading
        device, _, dtype = self._get_device()
        self._summarizer_device = device
        self._summarizer_dtype = dtype

        # language:
        self._language = language
        # directory:
        if isinstance(video_path, str):
            self._video_path = Path(video_path).resolve()
        else:
            self._video_path = video_path

    def _ensure_torch(self):
        """Ensure Torch is configured (lazy loading)."""
        import torch
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    @property
    def summarizer(self):
        """
        Lazy loading property for the summarizer pipeline.
        Only loads the model when actually needed, saving ~1.6GB VRAM.
        """
        if self._summarizer is None:
            print("[ParrotBot] Loading summarizer model (BART-large-cnn)...")
            from transformers import (
                pipeline,
                AutoModelForSeq2SeqLM,
                AutoTokenizer
            )
            self._ensure_torch()
            self._summarizer = pipeline(
                "summarization",
                tokenizer=AutoTokenizer.from_pretrained(
                    self.summarization_model
                ),
                model=AutoModelForSeq2SeqLM.from_pretrained(
                    self.summarization_model
                ),
                device=self._summarizer_device,
                torch_dtype=self._summarizer_dtype,
            )
            print(f"[ParrotBot] ✓ Summarizer loaded on {self._summarizer_device}")
        return self._summarizer

    @summarizer.setter
    def summarizer(self, value):
        """Allow external setting of summarizer (for compatibility)."""
        self._summarizer = value

    @summarizer.deleter
    def summarizer(self):
        """Delete summarizer and free VRAM."""
        if self._summarizer is not None:
            import torch
            del self._summarizer
            self._summarizer = None
            gc.collect()
            if self._summarizer_device.startswith('cuda'):
                torch.cuda.empty_cache()
            print("[ParrotBot] 🧹 Summarizer freed from VRAM")

    def transcript_to_vtt(self, transcript: dict, transcript_path: Path) -> str:
        """
        Convert a transcript to VTT format.
        """
        vtt = "WEBVTT\n\n"
        for i, chunk in enumerate(transcript['chunks'], start=1):
            start, end = chunk['timestamp']
            text = chunk['text'].replace("\n", " ")  # Replace newlines in text with spaces

            if start is None or end is None:
                print(f"Warning: Missing timestamp for chunk {i}, skipping this chunk.")
                continue

            # Convert timestamps to WebVTT format (HH:MM:SS.MMM)
            start_vtt = f"{int(start // 3600):02}:{int(start % 3600 // 60):02}:{int(start % 60):02}.{int(start * 1000 % 1000):03}"  # noqa
            end_vtt = f"{int(end // 3600):02}:{int(end % 3600 // 60):02}:{int(end % 60):02}.{int(end * 1000 % 1000):03}"  # noqa

            vtt += f"{i}\n{start_vtt} --> {end_vtt}\n{text}\n\n"
        # Save the VTT file
        try:
            with open(str(transcript_path), "w") as f:
                f.write(vtt)
            print(f'Saved VTT File on {transcript_path}')
        except Exception as exc:
            print(f"Error saving VTT file: {exc}")
        return vtt

    def audio_to_srt(
        self,
        audio_path: Path,
        asr=None,  # expects output of get_whisper_transcript() above
        speaker_names=None,  # e.g. ["Bot","Agent","Customer"]
        output_srt_path=None,
        pyannote_token: str = None,
        max_gap_s: float = 0.5,
        max_chars: int = 90,
        max_duration_s: float = 8.0,
        min_speakers: int = 1,
        max_speakers: int = 2,
        speaker_corrections: dict = None,  # Manual corrections for specific segments
        merge_short_segments: bool = True,  # Merge very short adjacent segments
        min_segment_duration: float = 0.5,  # Minimum duration for a segment
    ):
        """
        Build an SRT subtitle string from a call recording using WhisperX-aligned words and
        Pyannote-based diarization (speaker attribution). Optionally writes the result to disk.

        This function consumes a WhisperX-style transcript (with word-level timestamps),
        performs speaker diarization (optionally constrained to a given speaker count),
        assigns speakers to words, and groups words into readable subtitle segments with
        length, gap, and duration constraints.

        Parameters
        ----------
        audio_path : pathlib.Path
            Path to the audio file used for diarization (e.g., preconverted mono 16 kHz WAV).
            Even if `asr` is provided, this file is required to run the diarization pipeline.
        asr : dict, optional
            WhisperX transcript object containing aligned segments and words.
            Expected schema:
                {
                    "text": "...",
                    "language": "en",
                    "chunks": [
                        {
                            "text": "utterance text",
                            "timestamp": (start: float, end: float),
                            "words": [
                                {"word": "Hello", "start": 0.50, "end": 0.72},
                                ...
                            ]
                        },
                        ...
                    ]
                }
            If None or missing `chunks`, a ValueError is raised.
        speaker_names : list[str] | tuple[str] | None, optional
            Friendly labels to apply to speakers in **first-appearance order** after diarization.
            For example, `["Bot", "Agent", "Customer"]`. If not provided, WhisperX/Pyannote
            speaker IDs (e.g., "SPEAKER_00") are used as-is. If the number of detected
            speakers exceeds this list, remaining speakers retain their original IDs.
        output_srt_path : str | pathlib.Path | None, optional
            If provided, the generated SRT text is written to this path (UTF-8). If omitted,
            nothing is written to disk.
        pyannote_token : str | None, optional
            Hugging Face access token used by Pyannote diarization models. If not provided,
            the function attempts to read it from the `PYANNOTE_AUDIO_AUTH` environment variable.
            Required for diarization.
        max_gap_s : float, default=0.5
            Maximum allowed *silence* between consecutive words when aggregating them into a
            single SRT subtitle line. A larger value yields longer lines; a smaller value
            creates more, shorter lines.
        max_chars : int, default=90
            Soft limit on the number of characters per SRT subtitle line. When adding the next
            word would exceed this threshold, a new subtitle block is started.
        max_duration_s : float, default=8.0
            Maximum duration (seconds) permitted for a single subtitle block. If adding the next
            word would exceed this duration, a new block is started.
        min_speakers : int, default=1
            Lower bound on the number of speakers provided to the diarization pipeline.
            Useful to avoid the "everything merges into one speaker" failure mode.
        max_speakers : int, default=2
            Upper bound on the number of speakers provided to the diarization pipeline.
            Set both `min_speakers` and `max_speakers` to the exact expected number (e.g., 3)
            to force a fixed speaker count.
        speaker_corrections : dict | None, optional
            Mapping to apply manual, deterministic speaker fixes after diarization and before
            SRT grouping. The expected shape is flexible, but a common pattern is:
                {
                    # remap entire diarized IDs
                    "SPEAKER_00": "Bot",
                    # or time-bounded corrections
                    (start_s, end_s): "Customer"
                }
            When keys are tuples (start, end), any words whose timestamps fall within that
            interval are reassigned to the specified label/ID.
        merge_short_segments : bool, default=True
            If True, very short adjacent subtitle segments (e.g., created by rapid speaker
            switches or punctuation) may be merged when safe (same speaker, small gap, within
            `max_chars` and `max_duration_s`), improving readability.
        min_segment_duration : float, default=0.5
            Minimum duration (seconds) target when merging very short segments. Only applies
            if `merge_short_segments=True`.

        Returns
        -------
        str
            A UTF-8 string containing the SRT-formatted transcript with speaker labels, where
            each subtitle block follows the standard:
                <index>
                HH:MM:SS,mmm --> HH:MM:SS,mmm
                <Speaker>: <text>
            If `output_srt_path` is provided, the same content is also written to that file.

        Raises
        ------
        ValueError
            If `asr` is None or does not contain a `chunks` list with valid timestamps.
        RuntimeError
            If the diarization pipeline cannot be initialized (e.g., missing `pyannote_token`)
            or if internal alignment/speaker assignment fails unexpectedly.
        FileNotFoundError
            If `audio_path` does not exist.

        Notes
        -----
        - **Word-level accuracy**: Speaker assignment happens per word (not per sentence),
          allowing accurate handling of interruptions and fast turn-taking.
        - **Speaker mapping**: If `speaker_names` is provided, the first diarized speaker to
          appear in time is mapped to `speaker_names[0]`, the second to `[1]`, etc.
        - **Determinism**: Pyannote diarization can be non-deterministic across environments.
          Pinning dependency versions and disabling/controlling TF32 may help reproducibility.
        - **Performance**: On low-VRAM systems, consider running diarization on CPU while
          keeping ASR/alignment on GPU. The function itself is agnostic to device placement
          as long as the underlying pipeline is configured accordingly.

        Examples
        --------
        Basic usage with forced 3 speakers and file output:

        >>> srt = self.audio_to_srt(
        ...     audio_path=Path("call_16k_mono.wav"),
        ...     asr=transcript,  # from get_whisper_transcript() / WhisperX
        ...     speaker_names=["Bot", "Agent", "Customer"],
        ...     output_srt_path="call.srt",
        ...     pyannote_token=os.environ["PYANNOTE_AUDIO_AUTH"],
        ...     min_speakers=3, max_speakers=3,
        ... )

        Apply manual speaker correction for the first 8 seconds as "Bot":

        >>> srt = self.audio_to_srt(
        ...     audio_path=Path("call.wav"),
        ...     asr=transcript,
        ...     pyannote_token=token,
        ...     speaker_corrections={(0.0, 8.0): "Bot"},
        ... )

        Tighter line grouping (shorter blocks):

        >>> srt = self.audio_to_srt(
        ...     audio_path=Path("call.wav"),
        ...     asr=transcript,
        ...     pyannote_token=token,
        ...     max_gap_s=0.35, max_chars=70, max_duration_s=6.0,
        ... )
        """
        def _safe_float(x):
            try:
                xf = float(x)
                if math.isfinite(xf):
                    return xf
            except Exception:
                pass
            return None

        if not asr or not asr.get("chunks"):
            raise ValueError(
                "audio_to_srt requires the WhisperX transcript (chunks with words)."
            )

        import whisperx
        import torch

        # Use the existing _get_device method
        pipeline_idx, _, _ = self._get_device()
        # Determine device string for WhisperX/pyannote
        if isinstance(pipeline_idx, str):
            # MPS or other special device
            device = pipeline_idx
        elif pipeline_idx >= 0:
            # CUDA device
            device = f"cuda:{pipeline_idx}"
        else:
            # CPU
            device = "cpu"

        token = pyannote_token or HUGGINGFACEHUB_API_TOKEN
        if not token:
            raise RuntimeError(
                "Missing PYANNOTE token. Set PYANNOTE_AUDIO_AUTH or pass pyannote_token=..."
            )

        # 1) Run WhisperX diarization on the file
        try:
            diarizer = whisperx.diarize.DiarizationPipeline(
                use_auth_token=token,
                device=device
            )
        except Exception as e:
            if "mps" in str(e).lower() and device == "mps":
                print(f"[WhisperX] MPS diarization failed ({e}), falling back to CPU")
                device = "cpu"
                diarizer = whisperx.diarize.DiarizationPipeline(
                    use_auth_token=token,
                    device=device
                )
            else:
                raise

        if speaker_names and len(speaker_names) > 1:
            min_speakers = max(2, len(speaker_names) - 1)
            max_speakers = len(speaker_names) + 1
        diar = diarizer(
            str(audio_path),
            min_speakers=min_speakers,
            max_speakers=max_speakers,
        )
        # 2) Build segments for speaker assignment
        segments = []
        for ch in asr["chunks"]:
            s, e = ch.get("timestamp") or (None, None)
            s = _safe_float(s)
            e = _safe_float(e)
            if s is None or e is None or e <= s:
                continue
            seg_words = []
            for w in ch.get("words") or []:
                ws = _safe_float(w.get("start"))
                we = _safe_float(w.get("end"))
                token = (w.get("word") or "").strip()
                if ws is None or we is None or we <= ws or not token:
                    continue
                seg_words.append({"word": token, "start": ws, "end": we})
            segments.append({
                "start": s,
                "end": e,
                "text": ch.get("text") or "",
                "words": seg_words
            })

        # Assign speakers to words
        assigned = whisperx.assign_word_speakers(diar, {"segments": segments})
        segments = assigned.get("segments", [])

        # 3) Detect speaker changes and apply corrections
        speaker_segments = self._detect_speaker_segments(segments, min_segment_duration)

        # Apply manual corrections if provided
        if speaker_corrections:
            speaker_segments = self._apply_speaker_corrections(
                speaker_segments,
                speaker_corrections
            )

        # 4) Map speakers to names
        sp_map = self._create_speaker_mapping(
            speaker_segments,
            speaker_names
        )

        # 5) Generate SRT with improved speaker labels
        srt_lines = self._generate_srt_lines(
            speaker_segments,
            sp_map,
            max_gap_s,
            max_chars,
            max_duration_s,
            merge_short_segments
        )

        srt_text = ("\n".join(srt_lines) + "\n") if srt_lines else ""

        if output_srt_path:
            Path(output_srt_path).write_text(srt_text, encoding="utf-8")

        # Cleanup
        gc.collect()
        if device.startswith("cuda"):
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass

        return srt_text

    def _detect_speaker_segments(self, segments: list, min_duration: float = 0.5) -> list:
        """
        Detect speaker segments with better change detection.

        Groups consecutive words by the same speaker and detects speaker changes
        based on gaps in speech or explicit speaker labels.
        """
        if not segments:
            return []

        speaker_segments = []
        current_speaker = None
        current_start = None
        current_end = None
        current_words = []
        current_text = []

        for seg in segments:
            for w in seg.get("words") or []:
                word = w.get("word", "").strip()
                if not word:
                    continue

                start = w.get("start")
                end = w.get("end")
                speaker = w.get("speaker")

                if start is None or end is None:
                    continue

                # Detect speaker change
                speaker_changed = (current_speaker is not None and speaker != current_speaker)

                # Detect significant gap (might indicate speaker change)
                significant_gap = False
                if current_end is not None:
                    gap = start - current_end
                    significant_gap = gap > 0.9

                # Start new segment if speaker changed or significant gap
                if speaker_changed or significant_gap or current_speaker is None:
                    # Save current segment if it exists
                    if current_words and current_start is not None and current_end is not None:
                        duration = current_end - current_start
                        if duration >= min_duration or len(current_words) > 3:
                            speaker_segments.append({
                                "speaker": current_speaker,
                                "start": current_start,
                                "end": current_end,
                                "words": current_words,
                                "text": " ".join(current_text)
                            })

                    # Start new segment
                    current_speaker = speaker
                    current_start = start
                    current_end = end
                    current_words = [w]
                    current_text = [word]
                else:
                    # Continue current segment
                    current_end = max(current_end, end)
                    current_words.append(w)
                    current_text.append(word)

        # Don't forget the last segment
        if current_words and current_start is not None and current_end is not None:
            speaker_segments.append({
                "speaker": current_speaker,
                "start": current_start,
                "end": current_end,
                "words": current_words,
                "text": " ".join(current_text)
            })

        return speaker_segments

    def _apply_speaker_corrections(self, segments: list, corrections: dict) -> list:
        """
        Apply manual speaker corrections to specific segments.

        Args:
            segments: List of speaker segments
            corrections: Dict mapping segment index to correct speaker name
        """
        for idx, correction_speaker in corrections.items():
            if 0 <= idx < len(segments):
                segments[idx]["speaker"] = correction_speaker
                print(f"[Speaker Correction] Segment {idx}: -> {correction_speaker}")

        return segments

    def _create_speaker_mapping(self, segments: list, speaker_names: list = None) -> dict:
        """
        Create mapping from speaker IDs to names.

        Improved logic that better handles the initial recording message
        and subsequent speakers.
        """
        # Identify unique speakers by order of appearance
        seen_speakers = []
        for seg in segments:
            sp = seg.get("speaker")
            if sp and sp not in seen_speakers:
                seen_speakers.append(sp)

        sp_map = {}

        if speaker_names:
            # Special handling for recordings with initial disclaimer
            # Check if first segment is very early (< 10 seconds) and might be recording
            if segments and segments[0]["start"] < 10:
                first_text = segments[0]["text"].lower()
                # Common recording disclaimer patterns
                recording_patterns = [
                    "this call is being recorded",
                    "call may be recorded",
                    "recording for quality",
                    "this conversation is being recorded"
                ]

                is_recording = any(pattern in first_text for pattern in recording_patterns)

                if is_recording and len(speaker_names) > len(seen_speakers):
                    # First speaker is likely the recording, use first name for it
                    if seen_speakers:
                        sp_map[seen_speakers[0]] = speaker_names[0]  # "Recording" or similar
                    # Map remaining speakers starting from second name
                    for i, sp in enumerate(seen_speakers[1:], start=1):
                        if i < len(speaker_names):
                            sp_map[sp] = speaker_names[i]
                        else:
                            sp_map[sp] = f"Speaker{i}"
                else:
                    # Standard mapping
                    for i, sp in enumerate(seen_speakers):
                        if i < len(speaker_names):
                            sp_map[sp] = speaker_names[i]
                        else:
                            sp_map[sp] = f"Speaker{i+1}"
            else:
                # Standard mapping for normal conversations
                for i, sp in enumerate(seen_speakers):
                    if i < len(speaker_names):
                        sp_map[sp] = speaker_names[i]
                    else:
                        sp_map[sp] = f"Speaker{i+1}"
        else:
            # No names provided, use generic labels
            for i, sp in enumerate(seen_speakers):
                sp_map[sp] = f"Speaker{i+1}"

        # Handle None speaker
        sp_map[None] = "Unknown"

        return sp_map

    def _generate_srt_lines(
        self,
        segments: list,
        sp_map: dict,
        max_gap_s: float,
        max_chars: int,
        max_duration_s: float,
        merge_short: bool
    ) -> list:
        """
        Generate SRT lines from speaker segments.
        """
        def _fmt_srt_time(t: float) -> str:
            if t is None or not math.isfinite(t) or t < 0:
                t = 0.0
            ms = int(round(t * 1000.0))
            h, ms = divmod(ms, 3600000)
            m, ms = divmod(ms, 60000)
            s, ms = divmod(ms, 1000)
            return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

        srt_lines = []
        idx = 1

        for seg in segments:
            speaker = sp_map.get(seg["speaker"], seg["speaker"] or "Unknown")
            text = seg["text"].strip()

            if not text:
                continue

            # Split long segments if needed
            words = seg["words"]
            if len(text) > max_chars or (seg["end"] - seg["start"]) > max_duration_s:
                # Need to split this segment
                sub_segments = self._split_long_segment(
                    words,
                    max_chars,
                    max_duration_s,
                    max_gap_s
                )

                for sub_seg in sub_segments:
                    if sub_seg["text"].strip():
                        srt_lines.append(
                            f"{idx}\n"
                            f"{_fmt_srt_time(sub_seg['start'])} --> {_fmt_srt_time(sub_seg['end'])}\n"
                            f"{speaker}: {sub_seg['text']}\n"
                        )
                        idx += 1
            else:
                # Use segment as is
                srt_lines.append(
                    f"{idx}\n"
                    f"{_fmt_srt_time(seg['start'])} --> {_fmt_srt_time(seg['end'])}\n"
                    f"{speaker}: {text}\n"
                )
                idx += 1

        return srt_lines

    def _split_long_segment(
        self,
        words: list,
        max_chars: int,
        max_duration: float,
        max_gap: float
    ) -> list:
        """
        Split a long segment into smaller chunks for better readability.
        """
        sub_segments = []
        current_words = []
        current_start = None
        current_end = None
        current_text = []

        for w in words:
            word = w.get("word", "").strip()
            start = w.get("start")
            end = w.get("end")

            if not word or start is None or end is None:
                continue

            # Check if adding this word would exceed limits
            would_exceed_chars = len(" ".join(current_text + [word])) > max_chars
            would_exceed_duration = (
                current_start is not None and
                (end - current_start) > max_duration
            )

            # Check for natural break point (gap)
            is_natural_break = False
            if current_end is not None:
                gap = start - current_end
                is_natural_break = gap > max_gap

            if (would_exceed_chars or would_exceed_duration or is_natural_break) and current_text:
                # Save current sub-segment
                sub_segments.append({
                    "start": current_start,
                    "end": current_end,
                    "text": " ".join(current_text)
                })
                # Start new sub-segment
                current_words = [w]
                current_start = start
                current_end = end
                current_text = [word]
            else:
                # Add to current sub-segment
                if current_start is None:
                    current_start = start
                current_end = end
                current_words.append(w)
                current_text.append(word)

        # Don't forget the last sub-segment
        if current_text:
            sub_segments.append({
                "start": current_start,
                "end": current_end,
                "text": " ".join(current_text)
            })

        return sub_segments

    def format_timestamp(self, seconds):
        # This helper function takes the total seconds and formats it into hh:mm:ss,ms
        # (the fractional part is captured before truncating, so milliseconds are preserved)
        milliseconds = int((seconds % 1) * 1000)
        hours, remainder = divmod(int(seconds), 3600)
        minutes, secs = divmod(remainder, 60)
        return f"{hours:02}:{minutes:02}:{secs:02},{milliseconds:03}"

    def transcript_to_blocks(self, transcript: dict) -> list:
        """
        Convert a transcript to blocks.
        """
        blocks = []
        for i, chunk in enumerate(transcript['chunks'], start=1):
            current_window = {}
            start, end = chunk['timestamp']
            if start is None or end is None:
                print(f"Warning: Missing timestamp for chunk {i}, skipping this chunk.")
                continue

            start_srt = self.format_timestamp(start)
            end_srt = self.format_timestamp(end)
            text = chunk['text'].replace("\n", " ")  # Replace newlines in text with spaces
            current_window['id'] = i
            current_window['start_time'] = start_srt
            current_window['end_time'] = end_srt
            current_window['text'] = text
            blocks.append(current_window)
        return blocks

    def chunk_text(self, text, chunk_size, tokenizer):
        # Tokenize the text and get the number of tokens
        tokens = tokenizer.tokenize(text)
        # Split the tokens into chunks
        for i in range(0, len(tokens), chunk_size):
            yield tokenizer.convert_tokens_to_string(
                tokens[i:i+chunk_size]
            )

    def extract_audio(
        self,
        video_path: Path,
        audio_path: Path,
        compress_speed: bool = False,
        output_path: Optional[Path] = None,
        speed_factor: float = 1.5
    ):
        """
        Extract audio from video. Prefer WAV 16k mono for Whisper.
        """
        video_path = Path(video_path)
        audio_path = Path(audio_path)

        if audio_path.exists():
            print(f"Audio already extracted: {audio_path}")
            return

        # Extract as WAV 16k mono PCM
        print(f"Extracting audio (16k mono WAV) to: {audio_path}")
        from moviepy import VideoFileClip
        from pydub import AudioSegment
        clip = VideoFileClip(str(video_path))
        if not clip.audio:
            print("No audio found in video.")
            clip.close()
            return

        # moviepy/ffmpeg: pcm_s16le, 16k, mono
        # Ensure audio_path has .wav
        if audio_path.suffix.lower() != ".wav":
            audio_path = audio_path.with_suffix(".wav")

        clip.audio.write_audiofile(
            str(audio_path),
            fps=16000,
            nbytes=2,
            codec="pcm_s16le",
            ffmpeg_params=["-ac", "1"]
        )
        clip.audio.close()
        clip.close()

        # Optional speed compression (still output WAV @16k mono)
        if compress_speed:
            print(f"Compressing audio speed by factor: {speed_factor}")
            audio = AudioSegment.from_file(audio_path)
            sped = audio._spawn(audio.raw_data, overrides={"frame_rate": int(audio.frame_rate * speed_factor)})
            sped = sped.set_frame_rate(16000).set_channels(1).set_sample_width(2)
            sped.export(str(output_path or audio_path), format="wav")
            print(f"Compressed audio saved to: {output_path or audio_path}")
        else:
            print(f"Audio extracted: {audio_path}")

    def ensure_wav_16k_mono(self, src_path: Path) -> Path:
        """
        Ensure `src_path` is a 16 kHz mono PCM WAV. Returns the WAV path.
        - If src is not a .wav, write <stem>.wav
        - If src is already .wav, write <stem>.16k.wav to avoid in-place overwrite
        """
        from pydub import AudioSegment
        src_path = Path(src_path)
        if src_path.suffix.lower() == ".wav":
            out_path = src_path.with_name(f"{src_path.stem}.16k.wav")
        else:
            out_path = src_path.with_suffix(".wav")

        # Always (re)encode to guarantee 16k mono PCM s16le
        audio = AudioSegment.from_file(src_path)
        audio = (
            audio.set_frame_rate(16000)  # 16 kHz
            .set_channels(1)             # mono
            .set_sample_width(2)         # s16le
        )
        audio.export(str(out_path), format="wav")
        print(f"Transcoded to 16k mono WAV: {out_path}")
        return out_path

    def _get_whisperx_name(self, language: str = 'en', model_size: str = 'small', version: str = 'v3'):
        """
        Get the appropriate WhisperX model name based on language and size.

        WhisperX model naming conventions:
        - English-only models: "tiny.en", "base.en", "small.en", "medium.en"
        - Multilingual models: "tiny", "base", "small", "medium", "large", "large-v1", "large-v2", "large-v3", "turbo"

        Args:
            language: Language code (e.g., "en", "es", "fr")
            model_size: Model size ("tiny", "base", "small", "medium", "large", "turbo")
            version: Version suffix appended to sizes that need one (e.g., "large" -> "large-v3")

        Returns:
            str: WhisperX model name
        """
        if language.lower() == 'en' and model_size.lower() in {'tiny', 'base', 'small', 'medium'}:
            return f"{model_size}.en"
        elif model_size.lower() in {'tiny', 'base', 'small', 'medium', 'turbo'}:
            return model_size
        else:
            return f"{model_size}-{version}"

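    # Resolution examples for _get_whisperx_name(), following the branches above:
    #     ('en', 'small')  -> 'small.en'
    #     ('es', 'small')  -> 'small'
    #     ('en', 'large')  -> 'large-v3'   (default version='v3')
    #     ('en', 'turbo')  -> 'turbo'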
    def get_whisperx_transcript(
        self,
        audio_path: Path,
        language: str = "en",
        model_name: str = None,
        batch_size: int = 8,
        compute_type_gpu: str = "float16",
        compute_type_cpu: str = "int8"
    ):
        """
        WhisperX-based transcription with word-level timestamps.
        Returns:
            {
                "text": "...",
                "chunks": [
                    {
                        "text": "...",
                        "timestamp": (start, end),
                        "words": [{"word": "...", "start": ..., "end": ...}, ...]
                    },
                    ...
                ],
                "language": "en"
            }
        """
        def _safe_float(x):
            try:
                xf = float(x)
                if math.isfinite(xf):
                    return xf
            except Exception:
                pass
            return None

        # Lazy load whisperx (only when needed)
        import whisperx
        import torch

        # Use the existing _get_device method
        pipeline_idx, _, _ = self._get_device()
        # Determine device string for WhisperX
        if isinstance(pipeline_idx, str):
            # MPS or other special device
            device = pipeline_idx
        elif pipeline_idx >= 0:
            # CUDA device
            device = "cuda"
        else:
            # CPU
            device = "cpu"

        # Select compute type based on device
        if device.startswith("cuda"):
            compute_type = compute_type_gpu
        elif device == "mps":
            # MPS typically works better with float32
            compute_type = "float32"
        else:
            compute_type = compute_type_cpu

        # Model selection
        lang = (language or self._language).lower()

        if model_name:
            model_id = model_name
        else:
            model_id = self._get_whisperx_name(lang, self._model_size)

        # 1) ASR
        model = whisperx.load_model(
            model_id,
            device=device,
            compute_type=compute_type,
            language=language
        )
        audio = whisperx.load_audio(str(audio_path))
        asr_result = model.transcribe(audio, batch_size=batch_size)
        lang = asr_result.get("language", language)
        segs = asr_result.get("segments", []) or []

        # 2) Alignment → precise word times
        align_model, align_meta = whisperx.load_align_model(
            language_code=asr_result.get("language", language), device=device
        )
        aligned = whisperx.align(
            segs,
            align_model,
            align_meta,
            audio,
            device=device,
            return_char_alignments=False
        )

        # Build the return payload in the existing schema
        chunks = []
        full_text_parts = []
        for seg in aligned.get("segments", []):
            s = _safe_float(seg.get("start"))
            e = _safe_float(seg.get("end"))
            if s is None or e is None or e <= s:
                continue
            text = (seg.get("text") or "").strip()
            words_out = []
            for w in seg.get("words") or []:
                ws = _safe_float(w.get("start"))
                we = _safe_float(w.get("end"))
                token = (w.get("word") or "").strip()
                if ws is None or we is None or we <= ws or not token:
                    continue
                words_out.append({"word": token, "start": ws, "end": we})
            chunks.append({"text": text, "timestamp": (s, e), "words": words_out})
            if text:
                full_text_parts.append(text)

        # Cleanup
        del model
        del align_model
        gc.collect()
        try:
            if device.startswith("cuda"):
                torch.cuda.empty_cache()
        except Exception:
            pass

        return {"text": " ".join(full_text_parts).strip(), "chunks": chunks, "language": lang}

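    # Usage sketch (hypothetical loader instance and file name, not part of
    # this module): iterating word-level timings from the payload described above.
    #     result = loader.get_whisperx_transcript(Path("talk.16k.wav"), language="en")
    #     for seg in result["chunks"]:
    #         for w in seg["words"]:
    #             print(f'{w["start"]:7.2f}-{w["end"]:7.2f}  {w["word"]}')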
    def get_whisper_transcript(
        self,
        audio_path: Path,
        chunk_length: int = 30,
        word_timestamps: bool = False,
        manual_chunk: bool = True,       # Enable manual chunking
        max_chunk_duration: int = 60     # Maximum seconds per chunk for GPU processing
    ):
        """
        Enhanced Whisper transcription with manual chunking for GPU memory management.

        The key insight: we process smaller audio segments independently on GPU,
        then merge results with corrected timestamps based on each chunk's offset.
        """
        import soundfile
        # Model selection
        lang = (self._language or "en").lower()
        if self._model_name in (None, "", "whisper", "openai/whisper"):
            size = (self._model_size or "small").lower()
            if lang == "en" and size in {"tiny", "base", "small", "medium"}:
                model_id = f"openai/whisper-{size}.en"
            elif size == "turbo":
                model_id = "openai/whisper-large-v3-turbo"
            else:
                model_id = "openai/whisper-large-v3"
        else:
            model_id = self._model_name

        # Load audio once
        if not (audio_path.exists() and audio_path.stat().st_size > 0):
            return None

        wav, sr = soundfile.read(str(audio_path), always_2d=False)
        if wav.ndim == 2:
            wav = wav.mean(axis=1)
        wav = wav.astype(np.float32, copy=False)

        total_duration = len(wav) / float(sr)
        print(f"[Whisper] Total audio duration: {total_duration:.2f} seconds")

        # Device configuration
        device_idx, dev, torch_dtype = self._get_device()
        # Special handling for MPS or other non-standard devices
        if isinstance(device_idx, str):
            # MPS or other special case - treat as CPU for pipeline purposes
            pipeline_device_idx = -1
            print(
                f"[Whisper] Using {device_idx} device (will use CPU pipeline mode)"
            )
        else:
            pipeline_device_idx = device_idx

        # Determine if we need manual chunking
        # Rule of thumb: whisper-medium needs ~6GB for 60s of audio
        needs_manual_chunk = (
            manual_chunk and
            isinstance(device_idx, int) and device_idx >= 0 and  # Using GPU
            total_duration > max_chunk_duration                  # Audio is long
        )

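        # Worked example of the decision above (numbers are illustrative): a
        # 10-minute file (600 s) on a CUDA device with max_chunk_duration=60
        # takes the manual-chunking path; a 45 s clip, or any run on CPU/MPS,
        # falls through to the single HF-pipeline pass below.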
        print('[Whisper] Using model:', model_id, 'Chunking needed:', needs_manual_chunk)

        if needs_manual_chunk:
            print(
                f"[Whisper] Using manual chunking strategy (chunks of {max_chunk_duration}s)"
            )
            return self._process_chunks(
                wav, sr, model_id, lang, pipeline_device_idx, dev, torch_dtype,
                max_chunk_duration, word_timestamps
            )
        else:
            # Use the standard pipeline for short audio or CPU processing
            return self._process_pipeline(
                wav, sr, model_id, lang, pipeline_device_idx, dev, torch_dtype,
                chunk_length, word_timestamps
            )

    def _process_pipeline(
        self,
        wav: np.ndarray,
        sr: int,
        model_id: str,
        lang: str,
        device_idx: int,
        torch_dev: str,
        torch_dtype,
        chunk_length: int,
        word_timestamps: bool
    ):
        """Use HF pipeline's built-in chunking & timestamping."""
        # Lazy load transformers components (only when needed)
        from transformers import (
            pipeline,
            WhisperForConditionalGeneration,
            WhisperProcessor
        )

        is_english_only = (
            model_id.endswith('.en') or
            '-en' in model_id.split('/')[-1] or
            model_id.endswith('-en')
        )

        model = WhisperForConditionalGeneration.from_pretrained(
            model_id,
            attn_implementation="eager",  # silence SDPA warning + future-proof
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
        ).to(torch_dev)
        processor = WhisperProcessor.from_pretrained(model_id)

        chunk_length = int(chunk_length) if chunk_length else 30
        stride = 6 if chunk_length >= 8 else max(1, chunk_length // 5)

        asr = pipeline(
            task="automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            device=device_idx if device_idx >= 0 else -1,
            torch_dtype=torch_dtype,
            chunk_length_s=chunk_length,
            stride_length_s=stride,
            batch_size=1
        )

        # Timestamp mode
        ts_mode = "word" if word_timestamps else True

        generate_kwargs = {
            "temperature": 0.0,
            "compression_ratio_threshold": 2.4,
            "logprob_threshold": -1.0,
            "no_speech_threshold": 0.6,
        }
        # Language forcing only when not English-only
        if not is_english_only and lang:
            try:
                generate_kwargs["language"] = lang
                generate_kwargs["task"] = "transcribe"
            except Exception:
                pass

        # Let the pipeline handle attention_mask/padding
        out = asr(
            {"raw": wav, "sampling_rate": sr},
            return_timestamps=ts_mode,
            generate_kwargs=generate_kwargs,
        )

        chunks = out.get("chunks", [])
        # Normalize to the expected return shape
        out['text'] = out.get("text") or " ".join(c["text"] for c in chunks)
        return out

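    # Shape of the payload returned by _process_pipeline(), as produced by the
    # HF ASR pipeline with return_timestamps enabled (values are illustrative):
    #     {
    #         "text": "full transcript ...",
    #         "chunks": [{"text": "first segment", "timestamp": (0.0, 4.2)}, ...]
    #     }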
    def _process_chunks(
        self,
        wav: np.ndarray,
        sr: int,
        model_id: str,
        lang: str,
        device_idx: int,
        torch_dev: str,
        torch_dtype,
        max_chunk_duration: int,
        word_timestamps: bool,
        chunk_length: int = 60
    ):
        """
        Robust audio chunking with better error handling and memory management.

        This version addresses several key issues:
        1. The 'input_ids' error by properly configuring the pipeline
        2. The audio format issue in fallbacks
        3. Memory management for smaller GPUs
        4. Chunk processing stability
        """
        # Lazy load transformers components (only when needed)
        from transformers import (
            pipeline,
            WhisperForConditionalGeneration,
            WhisperProcessor
        )

        # For whisper-small on a 5.6GB GPU, we can use slightly larger chunks than medium
        # whisper-small uses ~1.5GB, leaving ~4GB for processing
        actual_chunk_duration = min(45, max_chunk_duration)  # Can handle 45s chunks with small

        # Set environment variable for better memory management
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

        # English-only models end with '.en' or contain '-en' in their name
        is_english_only = (
            model_id.endswith('.en') or
            '-en' in model_id.split('/')[-1] or
            model_id.endswith('-en')
        )

        print(f"[Whisper] Model type: {'English-only' if is_english_only else 'Multilingual'}")
        print(f"[Whisper] Using model: {model_id}")

        chunk_samples = actual_chunk_duration * sr
        overlap_duration = 2  # 2 seconds overlap to avoid cutting words
        overlap_samples = overlap_duration * sr

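        # Worked example of the stepping below: with 45 s chunks and a 2 s
        # overlap the window advances 43 s per iteration, so a 300 s file is
        # processed as chunks starting at 0, 43, 86, 129, 172, 215 and 258 s
        # (7 chunks in total).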
print(f"[Whisper] Processing {len(wav)/sr:.1f}s audio in {actual_chunk_duration}s chunks")
|
|
1227
|
+
|
|
1228
|
+
all_results = []
|
|
1229
|
+
offset = 0
|
|
1230
|
+
chunk_idx = 0
|
|
1231
|
+
|
|
1232
|
+
# Load model once for all chunks (whisper-small fits comfortably in memory)
|
|
1233
|
+
print(f"[Whisper] Loading {model_id} model...")
|
|
1234
|
+
model = WhisperForConditionalGeneration.from_pretrained(
|
|
1235
|
+
model_id,
|
|
1236
|
+
attn_implementation="eager", # <= fixes SDPA warning
|
|
1237
|
+
torch_dtype=torch_dtype,
|
|
1238
|
+
low_cpu_mem_usage=True,
|
|
1239
|
+
use_safetensors=True,
|
|
1240
|
+
).to(torch_dev)
|
|
1241
|
+
processor = WhisperProcessor.from_pretrained(model_id)
|
|
1242
|
+
|
|
1243
|
+
# Base generation kwargs - we'll be careful about what we pass
|
|
1244
|
+
base_generate_kwargs = {
|
|
1245
|
+
"temperature": 0.0, # Deterministic to reduce hallucinations
|
|
1246
|
+
"compression_ratio_threshold": 2.4, # Detect repetitive text
|
|
1247
|
+
"logprob_threshold": -1.0,
|
|
1248
|
+
"no_speech_threshold": 0.6,
|
|
1249
|
+
}
|
|
1250
|
+
|
|
1251
|
+
# Only add language forcing if it's properly supported
|
|
1252
|
+
if not is_english_only:
|
|
1253
|
+
try:
|
|
1254
|
+
forced_ids = processor.get_decoder_prompt_ids(
|
|
1255
|
+
language=lang,
|
|
1256
|
+
task="transcribe"
|
|
1257
|
+
)
|
|
1258
|
+
if forced_ids:
|
|
1259
|
+
base_generate_kwargs["language"] = lang
|
|
1260
|
+
base_generate_kwargs["task"] = "transcribe"
|
|
1261
|
+
# Note: We don't pass forced_decoder_ids directly as it can cause issues
|
|
1262
|
+
except Exception:
|
|
1263
|
+
# If the processor doesn't support this, that's fine
|
|
1264
|
+
pass
|
|
1265
|
+
|
|
1266
|
+
while offset < len(wav):
|
|
1267
|
+
# Extract chunk
|
|
1268
|
+
end_sample = min(offset + chunk_samples, len(wav))
|
|
1269
|
+
chunk_wav = wav[offset:end_sample]
|
|
1270
|
+
|
|
1271
|
+
# Calculate timing for this chunk
|
|
1272
|
+
time_offset = offset / float(sr)
|
|
1273
|
+
chunk_duration = len(chunk_wav) / float(sr)
|
|
1274
|
+
|
|
1275
|
+
print(f"[Whisper] Processing chunk {chunk_idx + 1} "
|
|
1276
|
+
f"({time_offset:.1f}s - {time_offset + chunk_duration:.1f}s)")
|
|
1277
|
+
|
|
1278
|
+
# Process this chunk with careful error handling
|
|
1279
|
+
chunk_processed = False
|
|
1280
|
+
attempts = [
|
|
1281
|
+
("standard", word_timestamps),
|
|
1282
|
+
("chunk_timestamps", False), # Fallback to chunk timestamps
|
|
1283
|
+
("basic", False) # Most basic mode
|
|
1284
|
+
]
|
|
1285
|
+
chunk_length = int(chunk_length) if chunk_length else 30
|
|
1286
|
+
stride = 6 if chunk_length >= 8 else max(1, chunk_length // 5)
|
|
1287
|
+
|
|
1288
|
+
for attempt_name, use_word_timestamps in attempts:
|
|
1289
|
+
if chunk_processed:
|
|
1290
|
+
break
|
|
1291
|
+
|
|
1292
|
+
try:
|
|
1293
|
+
# Create a fresh pipeline for each chunk to avoid state issues
|
|
1294
|
+
# This is important for avoiding the 'input_ids' error
|
|
1295
|
+
asr = pipeline(
|
|
1296
|
+
task="automatic-speech-recognition",
|
|
1297
|
+
model=model,
|
|
1298
|
+
tokenizer=processor.tokenizer,
|
|
1299
|
+
feature_extractor=processor.feature_extractor,
|
|
1300
|
+
device=device_idx if device_idx >= 0 else -1,
|
|
1301
|
+
chunk_length_s=chunk_length,
|
|
1302
|
+
stride_length_s=stride,
|
|
1303
|
+
batch_size=1,
|
|
1304
|
+
torch_dtype=torch_dtype,
|
|
1305
|
+
)
|
|
1306
|
+
|
|
1307
|
+
# Prepare audio input with the CORRECT format
|
|
1308
|
+
# This is crucial - the pipeline expects "raw" not "array"
|
|
1309
|
+
audio_input = {
|
|
1310
|
+
"raw": chunk_wav,
|
|
1311
|
+
"sampling_rate": sr
|
|
1312
|
+
}
|
|
1313
|
+
|
|
1314
|
+
# Determine timestamp mode based on current attempt
|
|
1315
|
+
if use_word_timestamps:
|
|
1316
|
+
timestamp_param = "word"
|
|
1317
|
+
else:
|
|
1318
|
+
timestamp_param = True # Chunk-level timestamps
|
|
1319
|
+
|
|
1320
|
+
# Use a clean copy of generate_kwargs for each attempt
|
|
1321
|
+
# This prevents accumulation of incompatible parameters
|
|
1322
|
+
generate_kwargs = base_generate_kwargs.copy()
|
|
1323
|
+
|
|
1324
|
+
# Process the chunk
|
|
1325
|
+
chunk_result = asr(
|
|
1326
|
+
audio_input,
|
|
1327
|
+
return_timestamps=timestamp_param,
|
|
1328
|
+
generate_kwargs=generate_kwargs
|
|
1329
|
+
)
|
|
1330
|
+
|
|
1331
|
+
# Successfully processed - now handle the results
|
|
1332
|
+
if chunk_result and "chunks" in chunk_result:
|
|
1333
|
+
for item in chunk_result["chunks"]:
|
|
1334
|
+
# Adjust timestamps for this chunk's position
|
|
1335
|
+
if "timestamp" in item and item["timestamp"]:
|
|
1336
|
+
start, end = item["timestamp"]
|
|
1337
|
+
if start is not None:
|
|
1338
|
+
start += time_offset
|
|
1339
|
+
if end is not None:
|
|
1340
|
+
end += time_offset
|
|
1341
|
+
item["timestamp"] = (start, end)
|
|
1342
|
+
|
|
1343
|
+
# Add metadata for merging
|
|
1344
|
+
item["_chunk_idx"] = chunk_idx
|
|
1345
|
+
item["_is_word"] = use_word_timestamps
|
|
1346
|
+
|
|
1347
|
+
all_results.extend(chunk_result["chunks"])
|
|
1348
|
+
print(f" ✓ Chunk {chunk_idx + 1}: {len(chunk_result['chunks'])} items "
|
|
1349
|
+
f"(mode: {attempt_name})")
|
|
1350
|
+
chunk_processed = True
|
|
1351
|
+
|
|
1352
|
+
# Clean up the pipeline to free memory
|
|
1353
|
+
del asr
|
|
1354
|
+
gc.collect()
|
|
1355
|
+
if device_idx >= 0:
|
|
1356
|
+
torch.cuda.empty_cache()
|
|
1357
|
+
|
|
1358
|
+
except Exception as e:
|
|
1359
|
+
error_msg = str(e)
|
|
1360
|
+
print(f" ✗ Attempt '{attempt_name}' failed: {error_msg[:100]}")
|
|
1361
|
+
|
|
1362
|
+
# Clean up on error
|
|
1363
|
+
if 'asr' in locals():
|
|
1364
|
+
del asr
|
|
1365
|
+
gc.collect()
|
|
1366
|
+
if device_idx >= 0:
|
|
1367
|
+
torch.cuda.empty_cache()
|
|
1368
|
+
|
|
1369
|
+
# Continue to next attempt
|
|
1370
|
+
continue
|
|
1371
|
+
|
|
1372
|
+
if not chunk_processed:
|
|
1373
|
+
print(f" âš Chunk {chunk_idx + 1} could not be processed, skipping")
|
|
1374
|
+
|
|
1375
|
+
# Move to next chunk
|
|
1376
|
+
if end_sample < len(wav):
|
|
1377
|
+
offset += chunk_samples - overlap_samples
|
|
1378
|
+
else:
|
|
1379
|
+
break
|
|
1380
|
+
|
|
1381
|
+
chunk_idx += 1
|
|
1382
|
+
|
|
1383
|
+
# Clean up model after all chunks
|
|
1384
|
+
del model
|
|
1385
|
+
del processor
|
|
1386
|
+
gc.collect()
|
|
1387
|
+
if device_idx >= 0:
|
|
1388
|
+
torch.cuda.empty_cache()
|
|
1389
|
+
|
|
1390
|
+
# Merge results based on whether we got word or chunk timestamps
|
|
1391
|
+
# Check what we actually got (might be mixed if some chunks fell back)
|
|
1392
|
+
has_word_timestamps = any(item.get("_is_word", False) for item in all_results)
|
|
1393
|
+
|
|
1394
|
+
if has_word_timestamps:
|
|
1395
|
+
print("[Whisper] Merging word-level timestamps...")
|
|
1396
|
+
final_chunks = self._merge_word_chunks(all_results, overlap_duration)
|
|
1397
|
+
else:
|
|
1398
|
+
print("[Whisper] Merging chunk-level timestamps...")
|
|
1399
|
+
final_chunks = self._merge_overlapping_chunks(all_results, overlap_duration)
|
|
1400
|
+
|
|
1401
|
+
# Clean the results to remove any garbage/hallucinations
|
|
1402
|
+
cleaned_chunks = []
|
|
1403
|
+
for chunk in final_chunks:
|
|
1404
|
+
text = chunk.get("text", "").strip()
|
|
1405
|
+
|
|
1406
|
+
# Filter out common hallucination patterns
|
|
1407
|
+
if not text:
|
|
1408
|
+
continue
|
|
1409
|
+
if len(set(text)) < 3 and len(text) > 10: # Repetitive characters
|
|
1410
|
+
continue
|
|
1411
|
+
if text.count("$") > len(text) * 0.5: # Too many special characters
|
|
1412
|
+
continue
|
|
1413
|
+
if text.count("�") > 0: # Unicode errors
|
|
1414
|
+
continue
|
|
1415
|
+
|
|
1416
|
+
chunk["text"] = text
|
|
1417
|
+
cleaned_chunks.append(chunk)
|
|
1418
|
+
|
|
1419
|
+
# Build the final result
|
|
1420
|
+
result = {
|
|
1421
|
+
"chunks": cleaned_chunks,
|
|
1422
|
+
"text": " ".join(ch["text"] for ch in cleaned_chunks),
|
|
1423
|
+
"word_timestamps": has_word_timestamps
|
|
1424
|
+
}
|
|
1425
|
+
|
|
1426
|
+
print(f"[Whisper] Transcription complete: {len(cleaned_chunks)} segments, "
|
|
1427
|
+
f"{len(result['text'].split())} words")
|
|
1428
|
+
|
|
1429
|
+
return result
|
|
1430
|
+
|
|
1431
|
+
    def _merge_overlapping_chunks(self, chunks: List[dict], overlap_duration: float) -> List[dict]:
        """
        Intelligently merge chunks that might have overlapping content.

        When we process overlapping audio segments, we might get duplicate
        transcriptions at the boundaries. This function:
        1. Detects potential duplicates based on timestamp overlap
        2. Keeps the best version (usually from the chunk where it's not at the edge)
        3. Maintains temporal order
        """
        if not chunks:
            return []

        # Sort by start time
        chunks.sort(key=lambda x: x.get("timestamp", (0,))[0] or 0)

        merged = []
        for chunk in chunks:
            if not chunk.get("text", "").strip():
                continue

            timestamp = chunk.get("timestamp", (None, None))
            if not timestamp or timestamp[0] is None:
                continue

            # Check if this chunk overlaps significantly with the last merged chunk
            if merged:
                last = merged[-1]
                last_ts = last.get("timestamp", (None, None))

                if last_ts and last_ts[1] and timestamp[0]:
                    # If timestamps overlap significantly
                    overlap = last_ts[1] - timestamp[0]
                    if overlap > 0.5:  # More than 0.5 second overlap
                        # Compare text similarity to detect duplicates
                        last_text = last.get("text", "").strip().lower()
                        curr_text = chunk.get("text", "").strip().lower()

                        # Simple duplicate detection
                        if last_text == curr_text:
                            # Skip this duplicate
                            continue

                        # If texts are very similar (e.g., one is subset of another)
                        if len(last_text) > 10 and len(curr_text) > 10:
                            if last_text in curr_text or curr_text in last_text:
                                # Keep the longer version
                                if len(curr_text) > len(last_text):
                                    merged[-1] = chunk
                                continue

            merged.append(chunk)

        return merged

    def _merge_word_chunks(self, chunks: List[dict], overlap_duration: float) -> List[dict]:
        """
        Special merging logic for word-level timestamps.

        Word-level chunks need more careful handling because:
        1. Words at boundaries might appear in multiple chunks
        2. Timestamp precision is more important
        3. We need to maintain word order exactly
        """
        if not chunks:
            return []

        # Sort by start timestamp
        chunks.sort(key=lambda x: (x.get("timestamp", (0,))[0] or 0, x.get("_chunk_idx", 0)))

        merged = []
        seen_words = set()  # Track (word, approximate_time) to avoid duplicates

        for chunk in chunks:
            word = chunk.get("text", "").strip()
            if not word:
                continue

            timestamp = chunk.get("timestamp", (None, None))
            if not timestamp or timestamp[0] is None:
                continue

            # Create a key for duplicate detection
            # Round timestamp to nearest 0.1s for fuzzy matching
            time_key = round(timestamp[0], 1)
            word_key = (word.lower(), time_key)

            # Skip if we've seen this word at approximately this time
            if word_key in seen_words:
                continue

            seen_words.add(word_key)
            merged.append(chunk)

        return merged

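    # Deduplication example for _merge_word_chunks(): "hello" reported at
    # 12.31 s by one chunk and at 12.34 s by the overlapping chunk both map to
    # the key ("hello", 12.3), so only the first occurrence is kept.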
    def clear_cuda(self):
        """
        Clear CUDA cache and free all GPU memory used by this loader.

        This method:
        1. Deletes the summarizer pipeline if loaded
        2. Forces garbage collection
        3. Clears PyTorch CUDA cache

        Call this method when done processing to free VRAM for other tasks.
        """
        freed_items = []

        # Free summarizer if it was loaded
        if self._summarizer is not None:
            del self.summarizer  # Uses the deleter which handles cleanup
            freed_items.append("summarizer")

        # Force garbage collection
        gc.collect()

        # Clear CUDA cache if on GPU (the device may be a "cuda*" string or a
        # non-negative CUDA index, including 0)
        device = getattr(self, '_summarizer_device', None)
        is_cuda_device = (
            (isinstance(device, str) and device.startswith('cuda')) or
            (isinstance(device, int) and device >= 0)
        )
        if is_cuda_device:
            try:
                torch.cuda.empty_cache()
                freed_items.append("CUDA cache")
            except Exception as e:
                print(f"[ParrotBot] Warning: Failed to clear CUDA cache: {e}")

        if freed_items:
            print(f"[ParrotBot] 🧹 Cleared: {', '.join(freed_items)}")
        else:
            print("[ParrotBot] 🧹 No GPU resources to clear")

    @abstractmethod
    async def _load(self, source: str, **kwargs) -> List[Document]:
        pass

    @abstractmethod
    async def load_video(self, url: str, video_title: str, transcript: str) -> list:
        pass
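    # Subclass sketch (hypothetical; the names below are illustrative, not part
    # of this module). A concrete loader only needs to implement the two
    # abstract coroutines declared above:
    #     class YoutubeVideoLoader(BaseVideoLoader):
    #         async def _load(self, source: str, **kwargs) -> List[Document]:
    #             ...  # download, extract_audio(), get_whisper_transcript(), build Documents
    #         async def load_video(self, url: str, video_title: str, transcript: str) -> list:
    #             ...  # wrap the transcript into Document chunks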