ai-parrot 0.17.2__cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (535)
  1. agentui/.prettierrc +15 -0
  2. agentui/QUICKSTART.md +272 -0
  3. agentui/README.md +59 -0
  4. agentui/env.example +16 -0
  5. agentui/jsconfig.json +14 -0
  6. agentui/package-lock.json +4242 -0
  7. agentui/package.json +34 -0
  8. agentui/scripts/postinstall/apply-patches.mjs +260 -0
  9. agentui/src/app.css +61 -0
  10. agentui/src/app.d.ts +13 -0
  11. agentui/src/app.html +12 -0
  12. agentui/src/components/LoadingSpinner.svelte +64 -0
  13. agentui/src/components/ThemeSwitcher.svelte +159 -0
  14. agentui/src/components/index.js +4 -0
  15. agentui/src/lib/api/bots.ts +60 -0
  16. agentui/src/lib/api/chat.ts +22 -0
  17. agentui/src/lib/api/http.ts +25 -0
  18. agentui/src/lib/components/BotCard.svelte +33 -0
  19. agentui/src/lib/components/ChatBubble.svelte +63 -0
  20. agentui/src/lib/components/Toast.svelte +21 -0
  21. agentui/src/lib/config.ts +20 -0
  22. agentui/src/lib/stores/auth.svelte.ts +73 -0
  23. agentui/src/lib/stores/theme.svelte.js +64 -0
  24. agentui/src/lib/stores/toast.svelte.ts +31 -0
  25. agentui/src/lib/utils/conversation.ts +39 -0
  26. agentui/src/routes/+layout.svelte +20 -0
  27. agentui/src/routes/+page.svelte +232 -0
  28. agentui/src/routes/login/+page.svelte +200 -0
  29. agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
  30. agentui/src/routes/talk/[agentId]/+page.ts +7 -0
  31. agentui/static/README.md +1 -0
  32. agentui/svelte.config.js +11 -0
  33. agentui/tailwind.config.ts +53 -0
  34. agentui/tsconfig.json +3 -0
  35. agentui/vite.config.ts +10 -0
  36. ai_parrot-0.17.2.dist-info/METADATA +472 -0
  37. ai_parrot-0.17.2.dist-info/RECORD +535 -0
  38. ai_parrot-0.17.2.dist-info/WHEEL +6 -0
  39. ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
  40. ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
  41. ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
  42. crew-builder/.prettierrc +15 -0
  43. crew-builder/QUICKSTART.md +259 -0
  44. crew-builder/README.md +113 -0
  45. crew-builder/env.example +17 -0
  46. crew-builder/jsconfig.json +14 -0
  47. crew-builder/package-lock.json +4182 -0
  48. crew-builder/package.json +37 -0
  49. crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
  50. crew-builder/src/app.css +62 -0
  51. crew-builder/src/app.d.ts +13 -0
  52. crew-builder/src/app.html +12 -0
  53. crew-builder/src/components/LoadingSpinner.svelte +64 -0
  54. crew-builder/src/components/ThemeSwitcher.svelte +149 -0
  55. crew-builder/src/components/index.js +9 -0
  56. crew-builder/src/lib/api/bots.ts +60 -0
  57. crew-builder/src/lib/api/chat.ts +80 -0
  58. crew-builder/src/lib/api/client.ts +56 -0
  59. crew-builder/src/lib/api/crew/crew.ts +136 -0
  60. crew-builder/src/lib/api/index.ts +5 -0
  61. crew-builder/src/lib/api/o365/auth.ts +65 -0
  62. crew-builder/src/lib/auth/auth.ts +54 -0
  63. crew-builder/src/lib/components/AgentNode.svelte +43 -0
  64. crew-builder/src/lib/components/BotCard.svelte +33 -0
  65. crew-builder/src/lib/components/ChatBubble.svelte +67 -0
  66. crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
  67. crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
  68. crew-builder/src/lib/components/JsonViewer.svelte +24 -0
  69. crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
  70. crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
  71. crew-builder/src/lib/components/Toast.svelte +67 -0
  72. crew-builder/src/lib/components/Toolbar.svelte +157 -0
  73. crew-builder/src/lib/components/index.ts +10 -0
  74. crew-builder/src/lib/config.ts +8 -0
  75. crew-builder/src/lib/stores/auth.svelte.ts +228 -0
  76. crew-builder/src/lib/stores/crewStore.ts +369 -0
  77. crew-builder/src/lib/stores/theme.svelte.js +145 -0
  78. crew-builder/src/lib/stores/toast.svelte.ts +69 -0
  79. crew-builder/src/lib/utils/conversation.ts +39 -0
  80. crew-builder/src/lib/utils/markdown.ts +122 -0
  81. crew-builder/src/lib/utils/talkHistory.ts +47 -0
  82. crew-builder/src/routes/+layout.svelte +20 -0
  83. crew-builder/src/routes/+page.svelte +539 -0
  84. crew-builder/src/routes/agents/+page.svelte +247 -0
  85. crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
  86. crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
  87. crew-builder/src/routes/builder/+page.svelte +204 -0
  88. crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
  89. crew-builder/src/routes/crew/ask/+page.ts +1 -0
  90. crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
  91. crew-builder/src/routes/login/+page.svelte +197 -0
  92. crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
  93. crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
  94. crew-builder/static/README.md +1 -0
  95. crew-builder/svelte.config.js +11 -0
  96. crew-builder/tailwind.config.ts +53 -0
  97. crew-builder/tsconfig.json +3 -0
  98. crew-builder/vite.config.ts +10 -0
  99. mcp_servers/calculator_server.py +309 -0
  100. parrot/__init__.py +27 -0
  101. parrot/__pycache__/__init__.cpython-310.pyc +0 -0
  102. parrot/__pycache__/version.cpython-310.pyc +0 -0
  103. parrot/_version.py +34 -0
  104. parrot/a2a/__init__.py +48 -0
  105. parrot/a2a/client.py +658 -0
  106. parrot/a2a/discovery.py +89 -0
  107. parrot/a2a/mixin.py +257 -0
  108. parrot/a2a/models.py +376 -0
  109. parrot/a2a/server.py +770 -0
  110. parrot/agents/__init__.py +29 -0
  111. parrot/bots/__init__.py +12 -0
  112. parrot/bots/a2a_agent.py +19 -0
  113. parrot/bots/abstract.py +3139 -0
  114. parrot/bots/agent.py +1129 -0
  115. parrot/bots/basic.py +9 -0
  116. parrot/bots/chatbot.py +669 -0
  117. parrot/bots/data.py +1618 -0
  118. parrot/bots/database/__init__.py +5 -0
  119. parrot/bots/database/abstract.py +3071 -0
  120. parrot/bots/database/cache.py +286 -0
  121. parrot/bots/database/models.py +468 -0
  122. parrot/bots/database/prompts.py +154 -0
  123. parrot/bots/database/retries.py +98 -0
  124. parrot/bots/database/router.py +269 -0
  125. parrot/bots/database/sql.py +41 -0
  126. parrot/bots/db/__init__.py +6 -0
  127. parrot/bots/db/abstract.py +556 -0
  128. parrot/bots/db/bigquery.py +602 -0
  129. parrot/bots/db/cache.py +85 -0
  130. parrot/bots/db/documentdb.py +668 -0
  131. parrot/bots/db/elastic.py +1014 -0
  132. parrot/bots/db/influx.py +898 -0
  133. parrot/bots/db/mock.py +96 -0
  134. parrot/bots/db/multi.py +783 -0
  135. parrot/bots/db/prompts.py +185 -0
  136. parrot/bots/db/sql.py +1255 -0
  137. parrot/bots/db/tools.py +212 -0
  138. parrot/bots/document.py +680 -0
  139. parrot/bots/hrbot.py +15 -0
  140. parrot/bots/kb.py +170 -0
  141. parrot/bots/mcp.py +36 -0
  142. parrot/bots/orchestration/README.md +463 -0
  143. parrot/bots/orchestration/__init__.py +1 -0
  144. parrot/bots/orchestration/agent.py +155 -0
  145. parrot/bots/orchestration/crew.py +3330 -0
  146. parrot/bots/orchestration/fsm.py +1179 -0
  147. parrot/bots/orchestration/hr.py +434 -0
  148. parrot/bots/orchestration/storage/__init__.py +4 -0
  149. parrot/bots/orchestration/storage/memory.py +100 -0
  150. parrot/bots/orchestration/storage/mixin.py +119 -0
  151. parrot/bots/orchestration/verify.py +202 -0
  152. parrot/bots/product.py +204 -0
  153. parrot/bots/prompts/__init__.py +96 -0
  154. parrot/bots/prompts/agents.py +155 -0
  155. parrot/bots/prompts/data.py +216 -0
  156. parrot/bots/prompts/output_generation.py +8 -0
  157. parrot/bots/scraper/__init__.py +3 -0
  158. parrot/bots/scraper/models.py +122 -0
  159. parrot/bots/scraper/scraper.py +1173 -0
  160. parrot/bots/scraper/templates.py +115 -0
  161. parrot/bots/stores/__init__.py +5 -0
  162. parrot/bots/stores/local.py +172 -0
  163. parrot/bots/webdev.py +81 -0
  164. parrot/cli.py +17 -0
  165. parrot/clients/__init__.py +16 -0
  166. parrot/clients/base.py +1491 -0
  167. parrot/clients/claude.py +1191 -0
  168. parrot/clients/factory.py +129 -0
  169. parrot/clients/google.py +4567 -0
  170. parrot/clients/gpt.py +1975 -0
  171. parrot/clients/grok.py +432 -0
  172. parrot/clients/groq.py +986 -0
  173. parrot/clients/hf.py +582 -0
  174. parrot/clients/models.py +18 -0
  175. parrot/conf.py +395 -0
  176. parrot/embeddings/__init__.py +9 -0
  177. parrot/embeddings/base.py +157 -0
  178. parrot/embeddings/google.py +98 -0
  179. parrot/embeddings/huggingface.py +74 -0
  180. parrot/embeddings/openai.py +84 -0
  181. parrot/embeddings/processor.py +88 -0
  182. parrot/exceptions.c +13868 -0
  183. parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
  184. parrot/exceptions.pxd +22 -0
  185. parrot/exceptions.pxi +15 -0
  186. parrot/exceptions.pyx +44 -0
  187. parrot/generators/__init__.py +29 -0
  188. parrot/generators/base.py +200 -0
  189. parrot/generators/html.py +293 -0
  190. parrot/generators/react.py +205 -0
  191. parrot/generators/streamlit.py +203 -0
  192. parrot/generators/template.py +105 -0
  193. parrot/handlers/__init__.py +4 -0
  194. parrot/handlers/agent.py +861 -0
  195. parrot/handlers/agents/__init__.py +1 -0
  196. parrot/handlers/agents/abstract.py +900 -0
  197. parrot/handlers/bots.py +338 -0
  198. parrot/handlers/chat.py +915 -0
  199. parrot/handlers/creation.sql +192 -0
  200. parrot/handlers/crew/ARCHITECTURE.md +362 -0
  201. parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
  202. parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
  203. parrot/handlers/crew/__init__.py +0 -0
  204. parrot/handlers/crew/handler.py +801 -0
  205. parrot/handlers/crew/models.py +229 -0
  206. parrot/handlers/crew/redis_persistence.py +523 -0
  207. parrot/handlers/jobs/__init__.py +10 -0
  208. parrot/handlers/jobs/job.py +384 -0
  209. parrot/handlers/jobs/mixin.py +627 -0
  210. parrot/handlers/jobs/models.py +115 -0
  211. parrot/handlers/jobs/worker.py +31 -0
  212. parrot/handlers/models.py +596 -0
  213. parrot/handlers/o365_auth.py +105 -0
  214. parrot/handlers/stream.py +337 -0
  215. parrot/interfaces/__init__.py +6 -0
  216. parrot/interfaces/aws.py +143 -0
  217. parrot/interfaces/credentials.py +113 -0
  218. parrot/interfaces/database.py +27 -0
  219. parrot/interfaces/google.py +1123 -0
  220. parrot/interfaces/hierarchy.py +1227 -0
  221. parrot/interfaces/http.py +651 -0
  222. parrot/interfaces/images/__init__.py +0 -0
  223. parrot/interfaces/images/plugins/__init__.py +24 -0
  224. parrot/interfaces/images/plugins/abstract.py +58 -0
  225. parrot/interfaces/images/plugins/analisys.py +148 -0
  226. parrot/interfaces/images/plugins/classify.py +150 -0
  227. parrot/interfaces/images/plugins/classifybase.py +182 -0
  228. parrot/interfaces/images/plugins/detect.py +150 -0
  229. parrot/interfaces/images/plugins/exif.py +1103 -0
  230. parrot/interfaces/images/plugins/hash.py +52 -0
  231. parrot/interfaces/images/plugins/vision.py +104 -0
  232. parrot/interfaces/images/plugins/yolo.py +66 -0
  233. parrot/interfaces/images/plugins/zerodetect.py +197 -0
  234. parrot/interfaces/o365.py +978 -0
  235. parrot/interfaces/onedrive.py +822 -0
  236. parrot/interfaces/sharepoint.py +1435 -0
  237. parrot/interfaces/soap.py +257 -0
  238. parrot/loaders/__init__.py +8 -0
  239. parrot/loaders/abstract.py +1131 -0
  240. parrot/loaders/audio.py +199 -0
  241. parrot/loaders/basepdf.py +53 -0
  242. parrot/loaders/basevideo.py +1568 -0
  243. parrot/loaders/csv.py +409 -0
  244. parrot/loaders/docx.py +116 -0
  245. parrot/loaders/epubloader.py +316 -0
  246. parrot/loaders/excel.py +199 -0
  247. parrot/loaders/factory.py +55 -0
  248. parrot/loaders/files/__init__.py +0 -0
  249. parrot/loaders/files/abstract.py +39 -0
  250. parrot/loaders/files/html.py +26 -0
  251. parrot/loaders/files/text.py +63 -0
  252. parrot/loaders/html.py +152 -0
  253. parrot/loaders/markdown.py +442 -0
  254. parrot/loaders/pdf.py +373 -0
  255. parrot/loaders/pdfmark.py +320 -0
  256. parrot/loaders/pdftables.py +506 -0
  257. parrot/loaders/ppt.py +476 -0
  258. parrot/loaders/qa.py +63 -0
  259. parrot/loaders/splitters/__init__.py +10 -0
  260. parrot/loaders/splitters/base.py +138 -0
  261. parrot/loaders/splitters/md.py +228 -0
  262. parrot/loaders/splitters/token.py +143 -0
  263. parrot/loaders/txt.py +26 -0
  264. parrot/loaders/video.py +89 -0
  265. parrot/loaders/videolocal.py +218 -0
  266. parrot/loaders/videounderstanding.py +377 -0
  267. parrot/loaders/vimeo.py +167 -0
  268. parrot/loaders/web.py +599 -0
  269. parrot/loaders/youtube.py +504 -0
  270. parrot/manager/__init__.py +5 -0
  271. parrot/manager/manager.py +1030 -0
  272. parrot/mcp/__init__.py +28 -0
  273. parrot/mcp/adapter.py +105 -0
  274. parrot/mcp/cli.py +174 -0
  275. parrot/mcp/client.py +119 -0
  276. parrot/mcp/config.py +75 -0
  277. parrot/mcp/integration.py +842 -0
  278. parrot/mcp/oauth.py +933 -0
  279. parrot/mcp/server.py +225 -0
  280. parrot/mcp/transports/__init__.py +3 -0
  281. parrot/mcp/transports/base.py +279 -0
  282. parrot/mcp/transports/grpc_session.py +163 -0
  283. parrot/mcp/transports/http.py +312 -0
  284. parrot/mcp/transports/mcp.proto +108 -0
  285. parrot/mcp/transports/quic.py +1082 -0
  286. parrot/mcp/transports/sse.py +330 -0
  287. parrot/mcp/transports/stdio.py +309 -0
  288. parrot/mcp/transports/unix.py +395 -0
  289. parrot/mcp/transports/websocket.py +547 -0
  290. parrot/memory/__init__.py +16 -0
  291. parrot/memory/abstract.py +209 -0
  292. parrot/memory/agent.py +32 -0
  293. parrot/memory/cache.py +175 -0
  294. parrot/memory/core.py +555 -0
  295. parrot/memory/file.py +153 -0
  296. parrot/memory/mem.py +131 -0
  297. parrot/memory/redis.py +613 -0
  298. parrot/models/__init__.py +46 -0
  299. parrot/models/basic.py +118 -0
  300. parrot/models/compliance.py +208 -0
  301. parrot/models/crew.py +395 -0
  302. parrot/models/detections.py +654 -0
  303. parrot/models/generation.py +85 -0
  304. parrot/models/google.py +223 -0
  305. parrot/models/groq.py +23 -0
  306. parrot/models/openai.py +30 -0
  307. parrot/models/outputs.py +285 -0
  308. parrot/models/responses.py +938 -0
  309. parrot/notifications/__init__.py +743 -0
  310. parrot/openapi/__init__.py +3 -0
  311. parrot/openapi/components.yaml +641 -0
  312. parrot/openapi/config.py +322 -0
  313. parrot/outputs/__init__.py +32 -0
  314. parrot/outputs/formats/__init__.py +108 -0
  315. parrot/outputs/formats/altair.py +359 -0
  316. parrot/outputs/formats/application.py +122 -0
  317. parrot/outputs/formats/base.py +351 -0
  318. parrot/outputs/formats/bokeh.py +356 -0
  319. parrot/outputs/formats/card.py +424 -0
  320. parrot/outputs/formats/chart.py +436 -0
  321. parrot/outputs/formats/d3.py +255 -0
  322. parrot/outputs/formats/echarts.py +310 -0
  323. parrot/outputs/formats/generators/__init__.py +0 -0
  324. parrot/outputs/formats/generators/abstract.py +61 -0
  325. parrot/outputs/formats/generators/panel.py +145 -0
  326. parrot/outputs/formats/generators/streamlit.py +86 -0
  327. parrot/outputs/formats/generators/terminal.py +63 -0
  328. parrot/outputs/formats/holoviews.py +310 -0
  329. parrot/outputs/formats/html.py +147 -0
  330. parrot/outputs/formats/jinja2.py +46 -0
  331. parrot/outputs/formats/json.py +87 -0
  332. parrot/outputs/formats/map.py +933 -0
  333. parrot/outputs/formats/markdown.py +172 -0
  334. parrot/outputs/formats/matplotlib.py +237 -0
  335. parrot/outputs/formats/mixins/__init__.py +0 -0
  336. parrot/outputs/formats/mixins/emaps.py +855 -0
  337. parrot/outputs/formats/plotly.py +341 -0
  338. parrot/outputs/formats/seaborn.py +310 -0
  339. parrot/outputs/formats/table.py +397 -0
  340. parrot/outputs/formats/template_report.py +138 -0
  341. parrot/outputs/formats/yaml.py +125 -0
  342. parrot/outputs/formatter.py +152 -0
  343. parrot/outputs/templates/__init__.py +95 -0
  344. parrot/pipelines/__init__.py +0 -0
  345. parrot/pipelines/abstract.py +210 -0
  346. parrot/pipelines/detector.py +124 -0
  347. parrot/pipelines/models.py +90 -0
  348. parrot/pipelines/planogram.py +3002 -0
  349. parrot/pipelines/table.sql +97 -0
  350. parrot/plugins/__init__.py +106 -0
  351. parrot/plugins/importer.py +80 -0
  352. parrot/py.typed +0 -0
  353. parrot/registry/__init__.py +18 -0
  354. parrot/registry/registry.py +594 -0
  355. parrot/scheduler/__init__.py +1189 -0
  356. parrot/scheduler/models.py +60 -0
  357. parrot/security/__init__.py +16 -0
  358. parrot/security/prompt_injection.py +268 -0
  359. parrot/security/security_events.sql +25 -0
  360. parrot/services/__init__.py +1 -0
  361. parrot/services/mcp/__init__.py +8 -0
  362. parrot/services/mcp/config.py +13 -0
  363. parrot/services/mcp/server.py +295 -0
  364. parrot/services/o365_remote_auth.py +235 -0
  365. parrot/stores/__init__.py +7 -0
  366. parrot/stores/abstract.py +352 -0
  367. parrot/stores/arango.py +1090 -0
  368. parrot/stores/bigquery.py +1377 -0
  369. parrot/stores/cache.py +106 -0
  370. parrot/stores/empty.py +10 -0
  371. parrot/stores/faiss_store.py +1157 -0
  372. parrot/stores/kb/__init__.py +9 -0
  373. parrot/stores/kb/abstract.py +68 -0
  374. parrot/stores/kb/cache.py +165 -0
  375. parrot/stores/kb/doc.py +325 -0
  376. parrot/stores/kb/hierarchy.py +346 -0
  377. parrot/stores/kb/local.py +457 -0
  378. parrot/stores/kb/prompt.py +28 -0
  379. parrot/stores/kb/redis.py +659 -0
  380. parrot/stores/kb/store.py +115 -0
  381. parrot/stores/kb/user.py +374 -0
  382. parrot/stores/models.py +59 -0
  383. parrot/stores/pgvector.py +3 -0
  384. parrot/stores/postgres.py +2853 -0
  385. parrot/stores/utils/__init__.py +0 -0
  386. parrot/stores/utils/chunking.py +197 -0
  387. parrot/telemetry/__init__.py +3 -0
  388. parrot/telemetry/mixin.py +111 -0
  389. parrot/template/__init__.py +3 -0
  390. parrot/template/engine.py +259 -0
  391. parrot/tools/__init__.py +23 -0
  392. parrot/tools/abstract.py +644 -0
  393. parrot/tools/agent.py +363 -0
  394. parrot/tools/arangodbsearch.py +537 -0
  395. parrot/tools/arxiv_tool.py +188 -0
  396. parrot/tools/calculator/__init__.py +3 -0
  397. parrot/tools/calculator/operations/__init__.py +38 -0
  398. parrot/tools/calculator/operations/calculus.py +80 -0
  399. parrot/tools/calculator/operations/statistics.py +76 -0
  400. parrot/tools/calculator/tool.py +150 -0
  401. parrot/tools/cloudwatch.py +988 -0
  402. parrot/tools/codeinterpreter/__init__.py +127 -0
  403. parrot/tools/codeinterpreter/executor.py +371 -0
  404. parrot/tools/codeinterpreter/internals.py +473 -0
  405. parrot/tools/codeinterpreter/models.py +643 -0
  406. parrot/tools/codeinterpreter/prompts.py +224 -0
  407. parrot/tools/codeinterpreter/tool.py +664 -0
  408. parrot/tools/company_info/__init__.py +6 -0
  409. parrot/tools/company_info/tool.py +1138 -0
  410. parrot/tools/correlationanalysis.py +437 -0
  411. parrot/tools/database/abstract.py +286 -0
  412. parrot/tools/database/bq.py +115 -0
  413. parrot/tools/database/cache.py +284 -0
  414. parrot/tools/database/models.py +95 -0
  415. parrot/tools/database/pg.py +343 -0
  416. parrot/tools/databasequery.py +1159 -0
  417. parrot/tools/db.py +1800 -0
  418. parrot/tools/ddgo.py +370 -0
  419. parrot/tools/decorators.py +271 -0
  420. parrot/tools/dftohtml.py +282 -0
  421. parrot/tools/document.py +549 -0
  422. parrot/tools/ecs.py +819 -0
  423. parrot/tools/edareport.py +368 -0
  424. parrot/tools/elasticsearch.py +1049 -0
  425. parrot/tools/employees.py +462 -0
  426. parrot/tools/epson/__init__.py +96 -0
  427. parrot/tools/excel.py +683 -0
  428. parrot/tools/file/__init__.py +13 -0
  429. parrot/tools/file/abstract.py +76 -0
  430. parrot/tools/file/gcs.py +378 -0
  431. parrot/tools/file/local.py +284 -0
  432. parrot/tools/file/s3.py +511 -0
  433. parrot/tools/file/tmp.py +309 -0
  434. parrot/tools/file/tool.py +501 -0
  435. parrot/tools/file_reader.py +129 -0
  436. parrot/tools/flowtask/__init__.py +19 -0
  437. parrot/tools/flowtask/tool.py +761 -0
  438. parrot/tools/gittoolkit.py +508 -0
  439. parrot/tools/google/__init__.py +18 -0
  440. parrot/tools/google/base.py +169 -0
  441. parrot/tools/google/tools.py +1251 -0
  442. parrot/tools/googlelocation.py +5 -0
  443. parrot/tools/googleroutes.py +5 -0
  444. parrot/tools/googlesearch.py +5 -0
  445. parrot/tools/googlesitesearch.py +5 -0
  446. parrot/tools/googlevoice.py +2 -0
  447. parrot/tools/gvoice.py +695 -0
  448. parrot/tools/ibisworld/README.md +225 -0
  449. parrot/tools/ibisworld/__init__.py +11 -0
  450. parrot/tools/ibisworld/tool.py +366 -0
  451. parrot/tools/jiratoolkit.py +1718 -0
  452. parrot/tools/manager.py +1098 -0
  453. parrot/tools/math.py +152 -0
  454. parrot/tools/metadata.py +476 -0
  455. parrot/tools/msteams.py +1621 -0
  456. parrot/tools/msword.py +635 -0
  457. parrot/tools/multidb.py +580 -0
  458. parrot/tools/multistoresearch.py +369 -0
  459. parrot/tools/networkninja.py +167 -0
  460. parrot/tools/nextstop/__init__.py +4 -0
  461. parrot/tools/nextstop/base.py +286 -0
  462. parrot/tools/nextstop/employee.py +733 -0
  463. parrot/tools/nextstop/store.py +462 -0
  464. parrot/tools/notification.py +435 -0
  465. parrot/tools/o365/__init__.py +42 -0
  466. parrot/tools/o365/base.py +295 -0
  467. parrot/tools/o365/bundle.py +522 -0
  468. parrot/tools/o365/events.py +554 -0
  469. parrot/tools/o365/mail.py +992 -0
  470. parrot/tools/o365/onedrive.py +497 -0
  471. parrot/tools/o365/sharepoint.py +641 -0
  472. parrot/tools/openapi_toolkit.py +904 -0
  473. parrot/tools/openweather.py +527 -0
  474. parrot/tools/pdfprint.py +1001 -0
  475. parrot/tools/powerbi.py +518 -0
  476. parrot/tools/powerpoint.py +1113 -0
  477. parrot/tools/pricestool.py +146 -0
  478. parrot/tools/products/__init__.py +246 -0
  479. parrot/tools/prophet_tool.py +171 -0
  480. parrot/tools/pythonpandas.py +630 -0
  481. parrot/tools/pythonrepl.py +910 -0
  482. parrot/tools/qsource.py +436 -0
  483. parrot/tools/querytoolkit.py +395 -0
  484. parrot/tools/quickeda.py +827 -0
  485. parrot/tools/resttool.py +553 -0
  486. parrot/tools/retail/__init__.py +0 -0
  487. parrot/tools/retail/bby.py +528 -0
  488. parrot/tools/sandboxtool.py +703 -0
  489. parrot/tools/sassie/__init__.py +352 -0
  490. parrot/tools/scraping/__init__.py +7 -0
  491. parrot/tools/scraping/docs/select.md +466 -0
  492. parrot/tools/scraping/documentation.md +1278 -0
  493. parrot/tools/scraping/driver.py +436 -0
  494. parrot/tools/scraping/models.py +576 -0
  495. parrot/tools/scraping/options.py +85 -0
  496. parrot/tools/scraping/orchestrator.py +517 -0
  497. parrot/tools/scraping/readme.md +740 -0
  498. parrot/tools/scraping/tool.py +3115 -0
  499. parrot/tools/seasonaldetection.py +642 -0
  500. parrot/tools/shell_tool/__init__.py +5 -0
  501. parrot/tools/shell_tool/actions.py +408 -0
  502. parrot/tools/shell_tool/engine.py +155 -0
  503. parrot/tools/shell_tool/models.py +322 -0
  504. parrot/tools/shell_tool/tool.py +442 -0
  505. parrot/tools/site_search.py +214 -0
  506. parrot/tools/textfile.py +418 -0
  507. parrot/tools/think.py +378 -0
  508. parrot/tools/toolkit.py +298 -0
  509. parrot/tools/webapp_tool.py +187 -0
  510. parrot/tools/whatif.py +1279 -0
  511. parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
  512. parrot/tools/workday/__init__.py +6 -0
  513. parrot/tools/workday/models.py +1389 -0
  514. parrot/tools/workday/tool.py +1293 -0
  515. parrot/tools/yfinance_tool.py +306 -0
  516. parrot/tools/zipcode.py +217 -0
  517. parrot/utils/__init__.py +2 -0
  518. parrot/utils/helpers.py +73 -0
  519. parrot/utils/parsers/__init__.py +5 -0
  520. parrot/utils/parsers/toml.c +12078 -0
  521. parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
  522. parrot/utils/parsers/toml.pyx +21 -0
  523. parrot/utils/toml.py +11 -0
  524. parrot/utils/types.cpp +20936 -0
  525. parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
  526. parrot/utils/types.pyx +213 -0
  527. parrot/utils/uv.py +11 -0
  528. parrot/version.py +10 -0
  529. parrot/yaml-rs/Cargo.lock +350 -0
  530. parrot/yaml-rs/Cargo.toml +19 -0
  531. parrot/yaml-rs/pyproject.toml +19 -0
  532. parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
  533. parrot/yaml-rs/src/lib.rs +222 -0
  534. requirements/docker-compose.yml +24 -0
  535. requirements/requirements-dev.txt +21 -0
parrot/loaders/basevideo.py
@@ -0,0 +1,1568 @@
+ from __future__ import annotations
+ from typing import Any, Union, List, Optional, TYPE_CHECKING
+ from collections.abc import Callable
+ from abc import abstractmethod
+ import gc
+ import os
+ import logging
+ import math
+ from pathlib import Path
+ import numpy as np
+ from ..conf import HUGGINGFACEHUB_API_TOKEN
+ from ..stores.models import Document
+ from .abstract import AbstractLoader
+
+ if TYPE_CHECKING:
+     from moviepy import VideoFileClip
+     from pydub import AudioSegment
+     import whisperx
+     import torch
+     from transformers import (
+         pipeline,
+         AutoModelForSeq2SeqLM,
+         AutoTokenizer,
+         WhisperProcessor,
+         WhisperForConditionalGeneration
+     )
+
+
+ logging.getLogger(name='numba').setLevel(logging.WARNING)
+ logging.getLogger(name='pydub.converter').setLevel(logging.WARNING)
+
+ def extract_video_id(url):
+     parts = url.split("?v=")
+     video_id = parts[1].split("&")[0]
+     return video_id
+
+ def _fmt_srt_time(t: float) -> str:
+     hrs, rem = divmod(int(t), 3600)
+     mins, secs = divmod(rem, 60)
+     ms = int((t - int(t)) * 1000)
+     return f"{hrs:02}:{mins:02}:{secs:02},{ms:03}"
+
+
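# Editor's sketch (not part of the packaged file): a quick check of the two
# module-level helpers above. extract_video_id() assumes a "?v=" query
# parameter; _fmt_srt_time() renders seconds as an SRT timestamp.
assert extract_video_id("https://www.youtube.com/watch?v=abc123&t=42") == "abc123"
assert _fmt_srt_time(3661.5) == "01:01:01,500"  # 1 h + 1 min + 1 s + 500 ms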
+ class BaseVideoLoader(AbstractLoader):
+     """
+     Generating Video transcripts from Videos.
+     """
+     extensions: List[str] = ['.youtube']
+     encoding = 'utf-8'
+
+     def __init__(
+         self,
+         source: Optional[Union[str, Path, List[Union[str, Path]]]] = None,
+         tokenizer: Callable[..., Any] = None,
+         text_splitter: Callable[..., Any] = None,
+         source_type: str = 'video',
+         language: str = "en",
+         video_path: Union[str, Path] = None,
+         download_video: bool = True,
+         diarization: bool = False,
+         **kwargs
+     ):
+         self._download_video: bool = download_video
+         self._diarization: bool = diarization
+         super().__init__(
+             source,
+             tokenizer=tokenizer,
+             text_splitter=text_splitter,
+             source_type=source_type,
+             **kwargs
+         )
+         if isinstance(source, str):
+             self.urls = [source]
+         else:
+             self.urls = source
+         self._task = kwargs.get('task', "automatic-speech-recognition")
+         # Topics:
+         self.topics: list = kwargs.get('topics', [])
+         self._model_size: str = kwargs.get('model_size', 'small')
+         self.summarization_model = "facebook/bart-large-cnn"
+         self._model_name: str = kwargs.get('model_name', 'whisper')
+         self._use_summary_pipeline: bool = kwargs.get('use_summary_pipeline', False)
+
+         # Lazy loading: don't load the summarizer until needed.
+         # This saves ~1.6GB of VRAM when summarization is disabled.
+         self._summarizer = None
+         self._summarizer_device = None
+         self._summarizer_dtype = None
+
+         # Store device info for lazy loading
+         device, _, dtype = self._get_device()
+         self._summarizer_device = device
+         self._summarizer_dtype = dtype
+
+         # language:
+         self._language = language
+         # directory:
+         if isinstance(video_path, str):
+             self._video_path = Path(video_path).resolve()
+         else:
+             self._video_path = video_path
+
+     def _ensure_torch(self):
+         """Ensure Torch is configured (lazy loading)."""
+         import torch
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.allow_tf32 = True
+
+     @property
+     def summarizer(self):
+         """
+         Lazy-loading property for the summarizer pipeline.
+         Only loads the model when actually needed, saving ~1.6GB VRAM.
+         """
+         if self._summarizer is None:
+             print("[ParrotBot] Loading summarizer model (BART-large-cnn)...")
+             from transformers import (
+                 pipeline,
+                 AutoModelForSeq2SeqLM,
+                 AutoTokenizer
+             )
+             self._ensure_torch()
+             self._summarizer = pipeline(
+                 "summarization",
+                 tokenizer=AutoTokenizer.from_pretrained(
+                     self.summarization_model
+                 ),
+                 model=AutoModelForSeq2SeqLM.from_pretrained(
+                     self.summarization_model
+                 ),
+                 device=self._summarizer_device,
+                 torch_dtype=self._summarizer_dtype,
+             )
+             print(f"[ParrotBot] ✓ Summarizer loaded on {self._summarizer_device}")
+         return self._summarizer
+
+     @summarizer.setter
+     def summarizer(self, value):
+         """Allow external setting of the summarizer (for compatibility)."""
+         self._summarizer = value
+
+     @summarizer.deleter
+     def summarizer(self):
+         """Delete the summarizer and free VRAM."""
+         if self._summarizer is not None:
+             import torch
+             del self._summarizer
+             self._summarizer = None
+             gc.collect()
+             if self._summarizer_device.startswith('cuda'):
+                 torch.cuda.empty_cache()
+             print("[ParrotBot] 🧹 Summarizer freed from VRAM")
+
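# Editor's usage sketch (hypothetical subclass and file name; not part of the
# diff): the BART summarizer is materialized on first attribute access and
# released explicitly with `del`, which is where the ~1.6GB VRAM saving comes from.
loader = SomeVideoLoader("video.mp4")   # hypothetical BaseVideoLoader subclass
result = loader.summarizer("very long transcript ...", max_length=60)
del loader.summarizer                   # frees the model and empties the CUDA cache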
+     def transcript_to_vtt(self, transcript: dict, transcript_path: Path) -> str:
+         """
+         Convert a transcript to VTT format.
+         """
+         vtt = "WEBVTT\n\n"
+         for i, chunk in enumerate(transcript['chunks'], start=1):
+             start, end = chunk['timestamp']
+             text = chunk['text'].replace("\n", " ")  # Replace newlines in text with spaces
+
+             if start is None or end is None:
+                 print(f"Warning: Missing timestamp for chunk {i}, skipping this chunk.")
+                 continue
+
+             # Convert timestamps to WebVTT format (HH:MM:SS.MMM)
+             start_vtt = f"{int(start // 3600):02}:{int(start % 3600 // 60):02}:{int(start % 60):02}.{int(start * 1000 % 1000):03}"  # noqa
+             end_vtt = f"{int(end // 3600):02}:{int(end % 3600 // 60):02}:{int(end % 60):02}.{int(end * 1000 % 1000):03}"  # noqa
+
+             vtt += f"{i}\n{start_vtt} --> {end_vtt}\n{text}\n\n"
+         # Save the VTT file
+         try:
+             with open(str(transcript_path), "w") as f:
+                 f.write(vtt)
+             print(f'Saved VTT File on {transcript_path}')
+         except Exception as exc:
+             print(f"Error saving VTT file: {exc}")
+         return vtt
+
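# Editor's example of the chunk schema transcript_to_vtt() consumes and the cue
# it emits (values invented):
transcript = {"chunks": [{"timestamp": (1.5, 3.0), "text": "Hello there"}]}
# After the "WEBVTT" header this yields:
#   1
#   00:00:01.500 --> 00:00:03.000
#   Hello there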
+     def audio_to_srt(
+         self,
+         audio_path: Path,
+         asr=None,  # expects the output of get_whisper_transcript() below
+         speaker_names=None,  # e.g. ["Bot", "Agent", "Customer"]
+         output_srt_path=None,
+         pyannote_token: str = None,
+         max_gap_s: float = 0.5,
+         max_chars: int = 90,
+         max_duration_s: float = 8.0,
+         min_speakers: int = 1,
+         max_speakers: int = 2,
+         speaker_corrections: dict = None,  # manual corrections for specific segments
+         merge_short_segments: bool = True,  # merge very short adjacent segments
+         min_segment_duration: float = 0.5,  # minimum duration for a segment
+     ):
+         """
+         Build an SRT subtitle string from a call recording using WhisperX-aligned words and
+         Pyannote-based diarization (speaker attribution). Optionally writes the result to disk.
+
+         This function consumes a WhisperX-style transcript (with word-level timestamps),
+         performs speaker diarization (optionally constrained to a given speaker count),
+         assigns speakers to words, and groups words into readable subtitle segments with
+         length, gap, and duration constraints.
+
+         Parameters
+         ----------
+         audio_path : pathlib.Path
+             Path to the audio file used for diarization (e.g., a preconverted mono 16 kHz WAV).
+             Even if `asr` is provided, this file is required to run the diarization pipeline.
+         asr : dict, optional
+             WhisperX transcript object containing aligned segments and words.
+             Expected schema:
+                 {
+                     "text": "...",
+                     "language": "en",
+                     "chunks": [
+                         {
+                             "text": "utterance text",
+                             "timestamp": (start: float, end: float),
+                             "words": [
+                                 {"word": "Hello", "start": 0.50, "end": 0.72},
+                                 ...
+                             ]
+                         },
+                         ...
+                     ]
+                 }
+             If None or missing `chunks`, a ValueError is raised.
+         speaker_names : list[str] | tuple[str] | None, optional
+             Friendly labels to apply to speakers in **first-appearance order** after diarization.
+             For example, `["Bot", "Agent", "Customer"]`. If not provided, WhisperX/Pyannote
+             speaker IDs (e.g., "SPEAKER_00") are used as-is. If the number of detected
+             speakers exceeds this list, remaining speakers retain their original IDs.
+         output_srt_path : str | pathlib.Path | None, optional
+             If provided, the generated SRT text is written to this path (UTF-8). If omitted,
+             nothing is written to disk.
+         pyannote_token : str | None, optional
+             Hugging Face access token used by Pyannote diarization models. If not provided,
+             the function attempts to read it from the `PYANNOTE_AUDIO_AUTH` environment variable.
+             Required for diarization.
+         max_gap_s : float, default=0.5
+             Maximum allowed *silence* between consecutive words when aggregating them into a
+             single SRT subtitle line. A larger value yields longer lines; a smaller value
+             creates more, shorter lines.
+         max_chars : int, default=90
+             Soft limit on the number of characters per SRT subtitle line. When adding the next
+             word would exceed this threshold, a new subtitle block is started.
+         max_duration_s : float, default=8.0
+             Maximum duration (seconds) permitted for a single subtitle block. If adding the next
+             word would exceed this duration, a new block is started.
+         min_speakers : int, default=1
+             Lower bound on the number of speakers provided to the diarization pipeline.
+             Useful to avoid the "everything merges into one speaker" failure mode.
+         max_speakers : int, default=2
+             Upper bound on the number of speakers provided to the diarization pipeline.
+             Set both `min_speakers` and `max_speakers` to the exact expected number (e.g., 3)
+             to force a fixed speaker count. Note that when `speaker_names` lists more than one
+             name, both bounds are widened around `len(speaker_names)` before diarization.
+         speaker_corrections : dict | None, optional
+             Manual, deterministic speaker fixes applied after diarization and before SRT
+             grouping. Keys are integer segment indexes (in the order produced by
+             speaker-change detection) and values are the replacement labels:
+                 {
+                     0: "Bot",       # relabel the first detected segment
+                     3: "Customer",  # relabel the fourth detected segment
+                 }
+         merge_short_segments : bool, default=True
+             If True, very short adjacent subtitle segments (e.g., created by rapid speaker
+             switches or punctuation) may be merged when safe (same speaker, small gap, within
+             `max_chars` and `max_duration_s`), improving readability.
+         min_segment_duration : float, default=0.5
+             Minimum duration (seconds) target when merging very short segments. Only applies
+             if `merge_short_segments=True`.
+
+         Returns
+         -------
+         str
+             A UTF-8 string containing the SRT-formatted transcript with speaker labels, where
+             each subtitle block follows the standard:
+                 <index>
+                 HH:MM:SS,mmm --> HH:MM:SS,mmm
+                 <Speaker>: <text>
+             If `output_srt_path` is provided, the same content is also written to that file.
+
+         Raises
+         ------
+         ValueError
+             If `asr` is None or does not contain a `chunks` list with valid timestamps.
+         RuntimeError
+             If the diarization pipeline cannot be initialized (e.g., missing `pyannote_token`)
+             or if internal alignment/speaker assignment fails unexpectedly.
+         FileNotFoundError
+             If `audio_path` does not exist.
+
+         Notes
+         -----
+         - **Word-level accuracy**: Speaker assignment happens per word (not per sentence),
+           allowing accurate handling of interruptions and fast turn-taking.
+         - **Speaker mapping**: If `speaker_names` is provided, the first diarized speaker to
+           appear in time is mapped to `speaker_names[0]`, the second to `[1]`, etc.
+         - **Determinism**: Pyannote diarization can be non-deterministic across environments.
+           Pinning dependency versions and disabling/controlling TF32 may help reproducibility.
+         - **Performance**: On low-VRAM systems, consider running diarization on CPU while
+           keeping ASR/alignment on GPU. The function itself is agnostic to device placement
+           as long as the underlying pipeline is configured accordingly.
+
+         Examples
+         --------
+         Basic usage with a forced 3-speaker count and file output:
+
+         >>> srt = self.audio_to_srt(
+         ...     audio_path=Path("call_16k_mono.wav"),
+         ...     asr=transcript,  # from get_whisper_transcript() / WhisperX
+         ...     speaker_names=["Bot", "Agent", "Customer"],
+         ...     output_srt_path="call.srt",
+         ...     pyannote_token=os.environ["PYANNOTE_AUDIO_AUTH"],
+         ...     min_speakers=3, max_speakers=3,
+         ... )
+
+         Apply a manual speaker correction, relabeling the first detected segment as "Bot":
+
+         >>> srt = self.audio_to_srt(
+         ...     audio_path=Path("call.wav"),
+         ...     asr=transcript,
+         ...     pyannote_token=token,
+         ...     speaker_corrections={0: "Bot"},
+         ... )
+
+         Tighter line grouping (shorter blocks):
+
+         >>> srt = self.audio_to_srt(
+         ...     audio_path=Path("call.wav"),
+         ...     asr=transcript,
+         ...     pyannote_token=token,
+         ...     max_gap_s=0.35, max_chars=70, max_duration_s=6.0,
+         ... )
+         """
+         def _safe_float(x):
+             try:
+                 xf = float(x)
+                 if math.isfinite(xf):
+                     return xf
+             except Exception:
+                 pass
+             return None
+
+         if not asr or not asr.get("chunks"):
+             raise ValueError(
+                 "audio_to_srt requires the WhisperX transcript (chunks with words)."
+             )
+
+         import whisperx
+         import torch
+
+         # Use the existing _get_device method
+         pipeline_idx, _, _ = self._get_device()
+         # Determine device string for WhisperX/pyannote
+         if isinstance(pipeline_idx, str):
+             # MPS or other special device
+             device = pipeline_idx
+         elif pipeline_idx >= 0:
+             # CUDA device
+             device = f"cuda:{pipeline_idx}"
+         else:
+             # CPU
+             device = "cpu"
+
+         token = pyannote_token or HUGGINGFACEHUB_API_TOKEN
+         if not token:
+             raise RuntimeError(
+                 "Missing PYANNOTE token. Set PYANNOTE_AUDIO_AUTH or pass pyannote_token=..."
+             )
+
+         # 1) Run WhisperX diarization on the file
+         try:
+             diarizer = whisperx.diarize.DiarizationPipeline(
+                 use_auth_token=token,
+                 device=device
+             )
+         except Exception as e:
+             if "mps" in str(e).lower() and device == "mps":
+                 print(f"[WhisperX] MPS diarization failed ({e}), falling back to CPU")
+                 device = "cpu"
+                 diarizer = whisperx.diarize.DiarizationPipeline(
+                     use_auth_token=token,
+                     device=device
+                 )
+             else:
+                 raise
+
+         if speaker_names and len(speaker_names) > 1:
+             # Widen the diarization bounds around the expected speaker count
+             min_speakers = max(2, len(speaker_names) - 1)
+             max_speakers = len(speaker_names) + 1
+         diar = diarizer(
+             str(audio_path),
+             min_speakers=min_speakers,
+             max_speakers=max_speakers,
+         )
+         # 2) Build segments for speaker assignment
+         segments = []
+         for ch in asr["chunks"]:
+             s, e = ch.get("timestamp") or (None, None)
+             s = _safe_float(s)
+             e = _safe_float(e)
+             if s is None or e is None or e <= s:
+                 continue
+             seg_words = []
+             for w in ch.get("words") or []:
+                 ws = _safe_float(w.get("start"))
+                 we = _safe_float(w.get("end"))
+                 token = (w.get("word") or "").strip()
+                 if ws is None or we is None or we <= ws or not token:
+                     continue
+                 seg_words.append({"word": token, "start": ws, "end": we})
+             segments.append({
+                 "start": s,
+                 "end": e,
+                 "text": ch.get("text") or "",
+                 "words": seg_words
+             })
+
+         # Assign speakers to words
+         assigned = whisperx.assign_word_speakers(diar, {"segments": segments})
+         segments = assigned.get("segments", [])
+
+         # 3) Detect speaker changes and apply corrections
+         speaker_segments = self._detect_speaker_segments(segments, min_segment_duration)
+
+         # Apply manual corrections if provided
+         if speaker_corrections:
+             speaker_segments = self._apply_speaker_corrections(
+                 speaker_segments,
+                 speaker_corrections
+             )
+
+         # 4) Map speakers to names
+         sp_map = self._create_speaker_mapping(
+             speaker_segments,
+             speaker_names
+         )
+
+         # 5) Generate SRT with improved speaker labels
+         srt_lines = self._generate_srt_lines(
+             speaker_segments,
+             sp_map,
+             max_gap_s,
+             max_chars,
+             max_duration_s,
+             merge_short_segments
+         )
+
+         srt_text = ("\n".join(srt_lines) + "\n") if srt_lines else ""
+
+         if output_srt_path:
+             Path(output_srt_path).write_text(srt_text, encoding="utf-8")
+
+         # Cleanup
+         gc.collect()
+         if device.startswith("cuda"):
+             try:
+                 torch.cuda.empty_cache()
+             except Exception:
+                 pass
+
+         return srt_text
+
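# Shape of the blocks audio_to_srt() returns, as assembled by
# _generate_srt_lines() below (editor's illustration; times and speakers invented):
#   1
#   00:00:00,500 --> 00:00:02,200
#   Agent: Hello, thanks for calling.
#
#   2
#   00:00:02,700 --> 00:00:04,100
#   Customer: Hi, I have a question.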
+     def _detect_speaker_segments(self, segments: list, min_duration: float = 0.5) -> list:
+         """
+         Detect speaker segments with better change detection.
+
+         Groups consecutive words by the same speaker and detects speaker changes
+         based on gaps in speech or explicit speaker labels.
+         """
+         if not segments:
+             return []
+
+         speaker_segments = []
+         current_speaker = None
+         current_start = None
+         current_end = None
+         current_words = []
+         current_text = []
+
+         for seg in segments:
+             for w in seg.get("words") or []:
+                 word = w.get("word", "").strip()
+                 if not word:
+                     continue
+
+                 start = w.get("start")
+                 end = w.get("end")
+                 speaker = w.get("speaker")
+
+                 if start is None or end is None:
+                     continue
+
+                 # Detect speaker change
+                 speaker_changed = (current_speaker is not None and speaker != current_speaker)
+
+                 # Detect significant gap (might indicate speaker change)
+                 significant_gap = False
+                 if current_end is not None:
+                     gap = start - current_end
+                     significant_gap = gap > 0.9
+
+                 # Start a new segment if the speaker changed or a significant gap occurred
+                 if speaker_changed or significant_gap or current_speaker is None:
+                     # Save the current segment if it exists
+                     if current_words and current_start is not None and current_end is not None:
+                         duration = current_end - current_start
+                         if duration >= min_duration or len(current_words) > 3:
+                             speaker_segments.append({
+                                 "speaker": current_speaker,
+                                 "start": current_start,
+                                 "end": current_end,
+                                 "words": current_words,
+                                 "text": " ".join(current_text)
+                             })
+
+                     # Start new segment
+                     current_speaker = speaker
+                     current_start = start
+                     current_end = end
+                     current_words = [w]
+                     current_text = [word]
+                 else:
+                     # Continue current segment
+                     current_end = max(current_end, end)
+                     current_words.append(w)
+                     current_text.append(word)
+
+         # Don't forget the last segment
+         if current_words and current_start is not None and current_end is not None:
+             speaker_segments.append({
+                 "speaker": current_speaker,
+                 "start": current_start,
+                 "end": current_end,
+                 "words": current_words,
+                 "text": " ".join(current_text)
+             })
+
+         return speaker_segments
+
+     def _apply_speaker_corrections(self, segments: list, corrections: dict) -> list:
+         """
+         Apply manual speaker corrections to specific segments.
+
+         Args:
+             segments: List of speaker segments
+             corrections: Dict mapping segment index to the correct speaker name
+         """
+         for idx, correction_speaker in corrections.items():
+             if 0 <= idx < len(segments):
+                 segments[idx]["speaker"] = correction_speaker
+                 print(f"[Speaker Correction] Segment {idx}: -> {correction_speaker}")
+
+         return segments
+
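# Editor's example of the shape _apply_speaker_corrections() consumes: integer
# segment indexes mapped to replacement labels (not time ranges).
segments = self._apply_speaker_corrections(segments, {0: "Bot", 3: "Customer"})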
+     def _create_speaker_mapping(self, segments: list, speaker_names: list = None) -> dict:
+         """
+         Create mapping from speaker IDs to names.
+
+         Improved logic that better handles the initial recording message
+         and subsequent speakers.
+         """
+         # Identify unique speakers by order of appearance
+         seen_speakers = []
+         for seg in segments:
+             sp = seg.get("speaker")
+             if sp and sp not in seen_speakers:
+                 seen_speakers.append(sp)
+
+         sp_map = {}
+
+         if speaker_names:
+             # Special handling for recordings with an initial disclaimer:
+             # check whether the first segment is very early (< 10 seconds) and might be the recording
+             if segments and segments[0]["start"] < 10:
+                 first_text = segments[0]["text"].lower()
+                 # Common recording disclaimer patterns
+                 recording_patterns = [
+                     "this call is being recorded",
+                     "call may be recorded",
+                     "recording for quality",
+                     "this conversation is being recorded"
+                 ]
+
+                 is_recording = any(pattern in first_text for pattern in recording_patterns)
+
+                 if is_recording and len(speaker_names) > len(seen_speakers):
+                     # First speaker is likely the recording; use the first name for it
+                     if seen_speakers:
+                         sp_map[seen_speakers[0]] = speaker_names[0]  # "Recording" or similar
+                         # Map remaining speakers starting from the second name
+                         for i, sp in enumerate(seen_speakers[1:], start=1):
+                             if i < len(speaker_names):
+                                 sp_map[sp] = speaker_names[i]
+                             else:
+                                 sp_map[sp] = f"Speaker{i}"
+                 else:
+                     # Standard mapping
+                     for i, sp in enumerate(seen_speakers):
+                         if i < len(speaker_names):
+                             sp_map[sp] = speaker_names[i]
+                         else:
+                             sp_map[sp] = f"Speaker{i+1}"
+             else:
+                 # Standard mapping for normal conversations
+                 for i, sp in enumerate(seen_speakers):
+                     if i < len(speaker_names):
+                         sp_map[sp] = speaker_names[i]
+                     else:
+                         sp_map[sp] = f"Speaker{i+1}"
+         else:
+             # No names provided, use generic labels
+             for i, sp in enumerate(seen_speakers):
+                 sp_map[sp] = f"Speaker{i+1}"
+
+         # Handle None speaker
+         sp_map[None] = "Unknown"
+
+         return sp_map
+
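# First-appearance mapping, illustrated (editor's example): diarized IDs are
# renamed in the order they first speak, and None falls back to "Unknown".
#   seen order ["SPEAKER_01", "SPEAKER_00"] with names ["Agent", "Customer"]
#   -> {"SPEAKER_01": "Agent", "SPEAKER_00": "Customer", None: "Unknown"}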
+     def _generate_srt_lines(
+         self,
+         segments: list,
+         sp_map: dict,
+         max_gap_s: float,
+         max_chars: int,
+         max_duration_s: float,
+         merge_short: bool
+     ) -> list:
+         """
+         Generate SRT lines from speaker segments.
+         """
+         def _fmt_srt_time(t: float) -> str:
+             if t is None or not math.isfinite(t) or t < 0:
+                 t = 0.0
+             ms = int(round(t * 1000.0))
+             h, ms = divmod(ms, 3600000)
+             m, ms = divmod(ms, 60000)
+             s, ms = divmod(ms, 1000)
+             return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
+
+         srt_lines = []
+         idx = 1
+
+         for seg in segments:
+             speaker = sp_map.get(seg["speaker"], seg["speaker"] or "Unknown")
+             text = seg["text"].strip()
+
+             if not text:
+                 continue
+
+             # Split long segments if needed
+             words = seg["words"]
+             if len(text) > max_chars or (seg["end"] - seg["start"]) > max_duration_s:
+                 # Need to split this segment
+                 sub_segments = self._split_long_segment(
+                     words,
+                     max_chars,
+                     max_duration_s,
+                     max_gap_s
+                 )
+
+                 for sub_seg in sub_segments:
+                     if sub_seg["text"].strip():
+                         srt_lines.append(
+                             f"{idx}\n"
+                             f"{_fmt_srt_time(sub_seg['start'])} --> {_fmt_srt_time(sub_seg['end'])}\n"
+                             f"{speaker}: {sub_seg['text']}\n"
+                         )
+                         idx += 1
+             else:
+                 # Use segment as is
+                 srt_lines.append(
+                     f"{idx}\n"
+                     f"{_fmt_srt_time(seg['start'])} --> {_fmt_srt_time(seg['end'])}\n"
+                     f"{speaker}: {text}\n"
+                 )
+                 idx += 1
+
+         return srt_lines
+
+     def _split_long_segment(
+         self,
+         words: list,
+         max_chars: int,
+         max_duration: float,
+         max_gap: float
+     ) -> list:
+         """
+         Split a long segment into smaller chunks for better readability.
+         """
+         sub_segments = []
+         current_words = []
+         current_start = None
+         current_end = None
+         current_text = []
+
+         for w in words:
+             word = w.get("word", "").strip()
+             start = w.get("start")
+             end = w.get("end")
+
+             if not word or start is None or end is None:
+                 continue
+
+             # Check if adding this word would exceed limits
+             would_exceed_chars = len(" ".join(current_text + [word])) > max_chars
+             would_exceed_duration = (
+                 current_start is not None and
+                 (end - current_start) > max_duration
+             )
+
+             # Check for natural break point (gap)
+             is_natural_break = False
+             if current_end is not None:
+                 gap = start - current_end
+                 is_natural_break = gap > max_gap
+
+             if (would_exceed_chars or would_exceed_duration or is_natural_break) and current_text:
+                 # Save current sub-segment
+                 sub_segments.append({
+                     "start": current_start,
+                     "end": current_end,
+                     "text": " ".join(current_text)
+                 })
+                 # Start new sub-segment
+                 current_words = [w]
+                 current_start = start
+                 current_end = end
+                 current_text = [word]
+             else:
+                 # Add to current sub-segment
+                 if current_start is None:
+                     current_start = start
+                 current_end = end
+                 current_words.append(w)
+                 current_text.append(word)
+
+         # Don't forget the last sub-segment
+         if current_text:
+             sub_segments.append({
+                 "start": current_start,
+                 "end": current_end,
+                 "text": " ".join(current_text)
+             })
+
+         return sub_segments
+
+     def format_timestamp(self, seconds):
+         # This helper takes the total seconds and formats it into hh:mm:ss,ms.
+         # The fractional part is captured before the integer divisions so the
+         # millisecond component is not lost.
+         milliseconds = int((seconds % 1) * 1000)
+         hours, remainder = divmod(int(seconds), 3600)
+         minutes, secs = divmod(remainder, 60)
+         return f"{hours:02}:{minutes:02}:{secs:02},{milliseconds:03}"
+
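# Editor's check of format_timestamp() after the millisecond fix above:
assert self.format_timestamp(3661.25) == "01:01:01,250"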
+     def transcript_to_blocks(self, transcript: dict) -> list:
+         """
+         Convert a transcript to blocks.
+         """
+         blocks = []
+         for i, chunk in enumerate(transcript['chunks'], start=1):
+             current_window = {}
+             start, end = chunk['timestamp']
+             if start is None or end is None:
+                 print(f"Warning: Missing timestamp for chunk {i}, skipping this chunk.")
+                 continue
+
+             start_srt = self.format_timestamp(start)
+             end_srt = self.format_timestamp(end)
+             text = chunk['text'].replace("\n", " ")  # Replace newlines in text with spaces
+             current_window['id'] = i
+             current_window['start_time'] = start_srt
+             current_window['end_time'] = end_srt
+             current_window['text'] = text
+             blocks.append(current_window)
+         return blocks
+
+     def chunk_text(self, text, chunk_size, tokenizer):
+         # Tokenize the text and get the number of tokens
+         tokens = tokenizer.tokenize(text)
+         # Split the tokens into chunks
+         for i in range(0, len(tokens), chunk_size):
+             yield tokenizer.convert_tokens_to_string(
+                 tokens[i:i + chunk_size]
+             )
+
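# Editor's usage sketch for chunk_text() (tokenizer choice is illustrative):
# split a long transcript into <=512-token pieces for the summarizer above.
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
long_text = "..."  # any transcript string
pieces = list(self.chunk_text(long_text, chunk_size=512, tokenizer=tok))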
+     def extract_audio(
+         self,
+         video_path: Path,
+         audio_path: Path,
+         compress_speed: bool = False,
+         output_path: Optional[Path] = None,
+         speed_factor: float = 1.5
+     ):
+         """
+         Extract audio from video. Prefer WAV 16k mono for Whisper.
+         """
+         video_path = Path(video_path)
+         audio_path = Path(audio_path)
+
+         if audio_path.exists():
+             print(f"Audio already extracted: {audio_path}")
+             return
+
+         # Extract as WAV 16k mono PCM
+         print(f"Extracting audio (16k mono WAV) to: {audio_path}")
+         from moviepy import VideoFileClip
+         from pydub import AudioSegment
+         clip = VideoFileClip(str(video_path))
+         if not clip.audio:
+             print("No audio found in video.")
+             clip.close()
+             return
+
+         # moviepy/ffmpeg: pcm_s16le, 16k, mono
+         # Ensure audio_path has .wav
+         if audio_path.suffix.lower() != ".wav":
+             audio_path = audio_path.with_suffix(".wav")
+
+         clip.audio.write_audiofile(
+             str(audio_path),
+             fps=16000,
+             nbytes=2,
+             codec="pcm_s16le",
+             ffmpeg_params=["-ac", "1"]
+         )
+         clip.audio.close()
+         clip.close()
+
+         # Optional speed compression (still output WAV @16k mono)
+         if compress_speed:
+             print(f"Compressing audio speed by factor: {speed_factor}")
+             audio = AudioSegment.from_file(audio_path)
+             sped = audio._spawn(
+                 audio.raw_data,
+                 overrides={"frame_rate": int(audio.frame_rate * speed_factor)}
+             )
+             sped = sped.set_frame_rate(16000).set_channels(1).set_sample_width(2)
+             sped.export(str(output_path or audio_path), format="wav")
+             print(f"Compressed audio saved to: {output_path or audio_path}")
+         else:
+             print(f"Audio extracted: {audio_path}")
+
+     def ensure_wav_16k_mono(self, src_path: Path) -> Path:
+         """
+         Ensure `src_path` is a 16 kHz mono PCM WAV. Returns the WAV path.
+         - If src is not a .wav, write <stem>.wav
+         - If src is already .wav, write <stem>.16k.wav to avoid in-place overwrite
+         """
+         from pydub import AudioSegment
+         src_path = Path(src_path)
+         if src_path.suffix.lower() == ".wav":
+             out_path = src_path.with_name(f"{src_path.stem}.16k.wav")
+         else:
+             out_path = src_path.with_suffix(".wav")
+
+         # Always (re)encode to guarantee 16k mono PCM s16le
+         audio = AudioSegment.from_file(src_path)
+         audio = (
+             audio.set_frame_rate(16000)  # 16 kHz
+             .set_channels(1)             # mono
+             .set_sample_width(2)         # s16le
+         )
+         audio.export(str(out_path), format="wav")
+         print(f"Transcoded to 16k mono WAV: {out_path}")
+         return out_path
+
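# Path behaviour of ensure_wav_16k_mono(), per the docstring (editor's note):
#   Path("call.mp3") -> writes and returns Path("call.wav")
#   Path("call.wav") -> writes and returns Path("call.16k.wav")  # no in-place overwrite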
+     def _get_whisperx_name(self, language: str = 'en', model_size: str = 'small', version: str = 'v3'):
+         """
+         Get the appropriate WhisperX model name based on language and size.
+
+         WhisperX model naming conventions:
+         - English-only models: "tiny.en", "base.en", "small.en", "medium.en"
+         - Multilingual models: "tiny", "base", "small", "medium", "large", "large-v1", "large-v2", "large-v3", "turbo"
+
+         Args:
+             language: Language code (e.g., "en", "es", "fr")
+             model_size: Model size ("tiny", "base", "small", "medium", "large", "turbo")
+             version: Version suffix appended to sized models (e.g., "v3" -> "large-v3")
+
+         Returns:
+             str: The WhisperX model name.
+         """
+         if language.lower() == 'en' and model_size.lower() in {'tiny', 'base', 'small', 'medium'}:
+             return f"{model_size}.en"
+         elif model_size.lower() in {'tiny', 'base', 'small', 'medium'}:
+             return f"{model_size}"
+         else:
+             return f"{model_size}-{version}"
+
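# Editor's examples of the names _get_whisperx_name() produces:
#   ("en", "small") -> "small.en"    # English-only checkpoint
#   ("es", "small") -> "small"       # multilingual checkpoint
#   ("en", "large") -> "large-v3"    # falls through to "<size>-<version>"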
895
+ def get_whisperx_transcript(
896
+ self,
897
+ audio_path: Path,
898
+ language: str = "en",
899
+ model_name: str = None,
900
+ batch_size: int = 8,
901
+ compute_type_gpu: str = "float16",
902
+ compute_type_cpu: str = "int8"
903
+ ):
904
+ """
905
+ WhisperX-based transcription with word-level timestamps.
906
+ Returns:
907
+ {
908
+ "text": "...",
909
+ "chunks": [
910
+ {
911
+ "text": "...",
912
+ "timestamp": (start, end),
913
+ "words": [{"word":"...", "start":..., "end":...}, ...]
914
+ },
915
+ ...
916
+ ],
917
+ "language": "en"
918
+ }
919
+ """
920
+ def _safe_float(x):
921
+ try:
922
+ xf = float(x)
923
+ if math.isfinite(xf):
924
+ return xf
925
+ except Exception:
926
+ pass
927
+ return None
928
+
929
+ # Lazy load whisperx (only when needed)
930
+ import whisperx
931
+ import torch
932
+
933
+ # Use the existing _get_device method
934
+ pipeline_idx, _, _ = self._get_device()
935
+ # Determine device string for WhisperX
936
+ if isinstance(pipeline_idx, str):
937
+ # MPS or other special device
938
+ device = pipeline_idx
939
+ elif pipeline_idx >= 0:
940
+ # CUDA device
941
+ device = "cuda"
942
+ else:
943
+ # CPU
944
+ device = "cpu"
945
+
946
+ # Select compute type based on device
947
+ if device.startswith("cuda"):
948
+ compute_type = compute_type_gpu
949
+ elif device == "mps":
950
+ # MPS typically works better with float32
951
+ compute_type = "float32"
952
+ else:
953
+ compute_type = compute_type_cpu
954
+
955
+        # Model selection
+        lang = (language or self._language).lower()
+
+        if model_name:
+            model_id = model_name
+        else:
+            model_id = self._get_whisperx_name(lang, self._model_size)
+
+        # 1) ASR
+        model = whisperx.load_model(
+            model_id,
+            device=device,
+            compute_type=compute_type,
+            language=language
+        )
+        audio = whisperx.load_audio(str(audio_path))
+        asr_result = model.transcribe(audio, batch_size=batch_size)
+        lang = asr_result.get("language", language)
+        segs = asr_result.get("segments", []) or []
+
+        # 2) Alignment → precise word times
+        align_model, align_meta = whisperx.load_align_model(
+            language_code=asr_result.get("language", language), device=device
+        )
+        aligned = whisperx.align(
+            segs,
+            align_model,
+            align_meta,
+            audio,
+            device=device,
+            return_char_alignments=False
+        )
+
+        # Build the return payload in the existing schema
+        chunks = []
+        full_text_parts = []
+        for seg in aligned.get("segments", []):
+            s = _safe_float(seg.get("start"))
+            e = _safe_float(seg.get("end"))
+            if s is None or e is None or e <= s:
+                continue
+            text = (seg.get("text") or "").strip()
+            words_out = []
+            for w in seg.get("words") or []:
+                ws = _safe_float(w.get("start"))
+                we = _safe_float(w.get("end"))
+                token = (w.get("word") or "").strip()
+                if ws is None or we is None or we <= ws or not token:
+                    continue
+                words_out.append({"word": token, "start": ws, "end": we})
+            chunks.append({"text": text, "timestamp": (s, e), "words": words_out})
+            if text:
+                full_text_parts.append(text)
+
+        # Cleanup
+        del model
+        del align_model
+        gc.collect()
+        try:
+            if device.startswith("cuda"):
+                torch.cuda.empty_cache()
+        except Exception:
+            pass
+
+        return {"text": " ".join(full_text_parts).strip(), "chunks": chunks, "language": lang}
+
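+    # Illustrative consumer of the payload returned above. This is a minimal
+    # sketch assuming the documented schema; `_chunks_to_srt` is a hypothetical
+    # helper name, not an existing API of this class.
+    @staticmethod
+    def _chunks_to_srt(chunks: list) -> str:
+        def _ts(t: float) -> str:
+            ms = int(t * 1000)
+            h, ms = divmod(ms, 3_600_000)
+            m, ms = divmod(ms, 60_000)
+            s, ms = divmod(ms, 1_000)
+            return f"{h:02}:{m:02}:{s:02},{ms:03}"
+        blocks = []
+        for i, ch in enumerate(chunks, 1):
+            start, end = ch["timestamp"]
+            blocks.append(f"{i}\n{_ts(start)} --> {_ts(end)}\n{ch['text'].strip()}\n")
+        return "\n".join(blocks)
+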
+    def get_whisper_transcript(
+        self,
+        audio_path: Path,
+        chunk_length: int = 30,
+        word_timestamps: bool = False,
+        manual_chunk: bool = True,  # New parameter to enable manual chunking
+        max_chunk_duration: int = 60  # Maximum seconds per chunk for GPU processing
+    ):
+        """
+        Enhanced Whisper transcription with manual chunking for GPU memory management.
+
+        The key insight: we process smaller audio segments independently on GPU,
+        then merge results with corrected timestamps based on each chunk's offset.
+        """
+        import soundfile
+        # Model selection
+        lang = (self._language or "en").lower()
+        if self._model_name in (None, "", "whisper", "openai/whisper"):
+            size = (self._model_size or "small").lower()
+            if lang == "en" and size in {"tiny", "base", "small", "medium"}:
+                model_id = f"openai/whisper-{size}.en"
+            elif size == "turbo":
+                model_id = "openai/whisper-large-v3-turbo"
+            else:
+                # Non-English requests (and "large") fall back to the
+                # multilingual large-v3 checkpoint
+                model_id = "openai/whisper-large-v3"
+        else:
+            model_id = self._model_name
+
+        # Load audio once
+        if not (audio_path.exists() and audio_path.stat().st_size > 0):
+            return None
+
+        wav, sr = soundfile.read(str(audio_path), always_2d=False)
+        if wav.ndim == 2:
+            wav = wav.mean(axis=1)
+        wav = wav.astype(np.float32, copy=False)
+
+        total_duration = len(wav) / float(sr)
+        print(f"[Whisper] Total audio duration: {total_duration:.2f} seconds")
+
+        # Device configuration
+        device_idx, dev, torch_dtype = self._get_device()
+        # Special handling for MPS or other non-standard devices
+        if isinstance(device_idx, str):
+            # MPS or other special case - treat as CPU for pipeline purposes
+            pipeline_device_idx = -1
+            print(
+                f"[Whisper] Using {device_idx} device (will use CPU pipeline mode)"
+            )
+        else:
+            pipeline_device_idx = device_idx
+
+        # Determine if we need manual chunking
+        # Rule of thumb: whisper-medium needs ~6GB for 60s of audio
+        needs_manual_chunk = (
+            manual_chunk and
+            isinstance(device_idx, int) and device_idx >= 0 and  # Using GPU
+            total_duration > max_chunk_duration  # Audio is long
+        )
+
+        print(f"[Whisper] Using model: {model_id}; chunking needed: {needs_manual_chunk}")
+
+        if needs_manual_chunk:
+            print(
+                f"[Whisper] Using manual chunking strategy (chunks of {max_chunk_duration}s)"
+            )
+            return self._process_chunks(
+                wav, sr, model_id, lang, pipeline_device_idx, dev, torch_dtype,
+                max_chunk_duration, word_timestamps
+            )
+        else:
+            # Use the standard pipeline for short audio or CPU processing
+            return self._process_pipeline(
+                wav, sr, model_id, lang, pipeline_device_idx, dev, torch_dtype,
+                chunk_length, word_timestamps
+            )
+
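+    # Worked example of the decision above: a 300s clip on GPU with
+    # max_chunk_duration=60 takes the manual path. _process_chunks (below) caps
+    # chunks at 45s with a 2s overlap, so offsets advance 43s at a time:
+    # 0, 43, 86, 129, 172, 215, 258 -> 7 GPU passes in total.
+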
+    def _process_pipeline(
+        self,
+        wav: np.ndarray,
+        sr: int,
+        model_id: str,
+        lang: str,
+        device_idx: int,
+        torch_dev: str,
+        torch_dtype,
+        chunk_length: int,
+        word_timestamps: bool
+    ):
+        """Use the HF pipeline's built-in chunking & timestamping."""
+        # Lazy load transformers components (only when needed)
+        from transformers import (
+            pipeline,
+            WhisperForConditionalGeneration,
+            WhisperProcessor
+        )
+
+        is_english_only = (
+            model_id.endswith('.en') or
+            '-en' in model_id.split('/')[-1] or
+            model_id.endswith('-en')
+        )
+
+        model = WhisperForConditionalGeneration.from_pretrained(
+            model_id,
+            attn_implementation="eager",  # silence SDPA warning + future-proof
+            torch_dtype=torch_dtype,
+            low_cpu_mem_usage=True,
+        ).to(torch_dev)
+        processor = WhisperProcessor.from_pretrained(model_id)
+
+        chunk_length = int(chunk_length) if chunk_length else 30
+        stride = 6 if chunk_length >= 8 else max(1, chunk_length // 5)
+
+        asr = pipeline(
+            task="automatic-speech-recognition",
+            model=model,
+            tokenizer=processor.tokenizer,
+            feature_extractor=processor.feature_extractor,
+            device=device_idx if device_idx >= 0 else -1,
+            torch_dtype=torch_dtype,
+            chunk_length_s=chunk_length,
+            stride_length_s=stride,
+            batch_size=1
+        )
+
+        # Timestamp mode
+        ts_mode = "word" if word_timestamps else True
+
+        generate_kwargs = {
+            "temperature": 0.0,
+            "compression_ratio_threshold": 2.4,
+            "logprob_threshold": -1.0,
+            "no_speech_threshold": 0.6,
+        }
+        # Language forcing only when not English-only
+        if not is_english_only and lang:
+            generate_kwargs["language"] = lang
+            generate_kwargs["task"] = "transcribe"
+
+        # Let the pipeline handle attention_mask/padding
+        out = asr(
+            {"raw": wav, "sampling_rate": sr},
+            return_timestamps=ts_mode,
+            generate_kwargs=generate_kwargs,
+        )
+
+        chunks = out.get("chunks", [])
+        # Normalize to the shared return shape
+        out['text'] = out.get("text") or " ".join(c["text"] for c in chunks)
+        return out
+
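+    # With return_timestamps enabled, the HF ASR pipeline returns a dict shaped
+    # roughly like {"text": "...", "chunks": [{"text": "...", "timestamp": (0.0, 4.2)}, ...]},
+    # which is why the manual-chunk path below can reuse the same merge logic.
+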
+    def _process_chunks(
+        self,
+        wav: np.ndarray,
+        sr: int,
+        model_id: str,
+        lang: str,
+        device_idx: int,
+        torch_dev: str,
+        torch_dtype,
+        max_chunk_duration: int,
+        word_timestamps: bool,
+        chunk_length: int = 60
+    ):
+        """
+        Robust audio chunking with better error handling and memory management.
+
+        This version addresses several key issues:
+        1. The 'input_ids' error, by properly configuring the pipeline
+        2. The audio format issue in fallbacks
+        3. Memory management for smaller GPUs
+        4. Chunk processing stability
+        """
+        # Lazy load transformers components (only when needed)
+        from transformers import (
+            pipeline,
+            WhisperForConditionalGeneration,
+            WhisperProcessor
+        )
+        import torch  # needed below for torch.cuda.empty_cache()
+
+        # For whisper-small on a 5.6GB GPU, we can use slightly larger chunks than medium:
+        # whisper-small uses ~1.5GB, leaving ~4GB for processing
+        actual_chunk_duration = min(45, max_chunk_duration)  # Can handle 45s chunks with small
+
+        # Set environment variable for better memory management
+        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
+        # English-only models end with '.en' or contain '-en' in their name
+        is_english_only = (
+            model_id.endswith('.en') or
+            '-en' in model_id.split('/')[-1] or
+            model_id.endswith('-en')
+        )
+
+        print(f"[Whisper] Model type: {'English-only' if is_english_only else 'Multilingual'}")
+        print(f"[Whisper] Using model: {model_id}")
+
+        chunk_samples = actual_chunk_duration * sr
+        overlap_duration = 2  # 2 seconds overlap to avoid cutting words
+        overlap_samples = overlap_duration * sr
+
+        print(f"[Whisper] Processing {len(wav)/sr:.1f}s audio in {actual_chunk_duration}s chunks")
+
+        all_results = []
+        offset = 0
+        chunk_idx = 0
+
+        # Load model once for all chunks (whisper-small fits comfortably in memory)
+        print(f"[Whisper] Loading {model_id} model...")
+        model = WhisperForConditionalGeneration.from_pretrained(
+            model_id,
+            attn_implementation="eager",  # <= fixes SDPA warning
+            torch_dtype=torch_dtype,
+            low_cpu_mem_usage=True,
+            use_safetensors=True,
+        ).to(torch_dev)
+        processor = WhisperProcessor.from_pretrained(model_id)
+
+        # Base generation kwargs - we'll be careful about what we pass
+        base_generate_kwargs = {
+            "temperature": 0.0,  # Deterministic to reduce hallucinations
+            "compression_ratio_threshold": 2.4,  # Detect repetitive text
+            "logprob_threshold": -1.0,
+            "no_speech_threshold": 0.6,
+        }
+
+        # Only add language forcing if it's properly supported
+        if not is_english_only:
+            try:
+                forced_ids = processor.get_decoder_prompt_ids(
+                    language=lang,
+                    task="transcribe"
+                )
+                if forced_ids:
+                    base_generate_kwargs["language"] = lang
+                    base_generate_kwargs["task"] = "transcribe"
+                # Note: we don't pass forced_decoder_ids directly as it can cause issues
+            except Exception:
+                # If the processor doesn't support this, that's fine
+                pass
+
+        while offset < len(wav):
+            # Extract chunk
+            end_sample = min(offset + chunk_samples, len(wav))
+            chunk_wav = wav[offset:end_sample]
+
+            # Calculate timing for this chunk
+            time_offset = offset / float(sr)
+            chunk_duration = len(chunk_wav) / float(sr)
+
+            print(f"[Whisper] Processing chunk {chunk_idx + 1} "
+                  f"({time_offset:.1f}s - {time_offset + chunk_duration:.1f}s)")
+
+            # Process this chunk with careful error handling
+            chunk_processed = False
+            attempts = [
+                ("standard", word_timestamps),
+                ("chunk_timestamps", False),  # Fallback to chunk timestamps
+                ("basic", False)  # Most basic mode
+            ]
+            chunk_length = int(chunk_length) if chunk_length else 30
+            stride = 6 if chunk_length >= 8 else max(1, chunk_length // 5)
+
+            for attempt_name, use_word_timestamps in attempts:
+                if chunk_processed:
+                    break
+
+                try:
+                    # Create a fresh pipeline for each chunk to avoid state issues.
+                    # This is important for avoiding the 'input_ids' error
+                    asr = pipeline(
+                        task="automatic-speech-recognition",
+                        model=model,
+                        tokenizer=processor.tokenizer,
+                        feature_extractor=processor.feature_extractor,
+                        device=device_idx if device_idx >= 0 else -1,
+                        chunk_length_s=chunk_length,
+                        stride_length_s=stride,
+                        batch_size=1,
+                        torch_dtype=torch_dtype,
+                    )
+
+                    # Prepare audio input with the CORRECT format.
+                    # This is crucial - the pipeline expects "raw", not "array"
+                    audio_input = {
+                        "raw": chunk_wav,
+                        "sampling_rate": sr
+                    }
+
+                    # Determine timestamp mode based on current attempt
+                    if use_word_timestamps:
+                        timestamp_param = "word"
+                    else:
+                        timestamp_param = True  # Chunk-level timestamps
+
+                    # Use a clean copy of generate_kwargs for each attempt.
+                    # This prevents accumulation of incompatible parameters
+                    generate_kwargs = base_generate_kwargs.copy()
+
+                    # Process the chunk
+                    chunk_result = asr(
+                        audio_input,
+                        return_timestamps=timestamp_param,
+                        generate_kwargs=generate_kwargs
+                    )
+
+                    # Successfully processed - now handle the results
+                    if chunk_result and "chunks" in chunk_result:
+                        for item in chunk_result["chunks"]:
+                            # Adjust timestamps for this chunk's position
+                            if "timestamp" in item and item["timestamp"]:
+                                start, end = item["timestamp"]
+                                if start is not None:
+                                    start += time_offset
+                                if end is not None:
+                                    end += time_offset
+                                item["timestamp"] = (start, end)
+
+                            # Add metadata for merging
+                            item["_chunk_idx"] = chunk_idx
+                            item["_is_word"] = use_word_timestamps
+
+                        all_results.extend(chunk_result["chunks"])
+                        print(f"  ✓ Chunk {chunk_idx + 1}: {len(chunk_result['chunks'])} items "
+                              f"(mode: {attempt_name})")
+                        chunk_processed = True
+
+                    # Clean up the pipeline to free memory
+                    del asr
+                    gc.collect()
+                    if device_idx >= 0:
+                        torch.cuda.empty_cache()
+
+                except Exception as e:
+                    error_msg = str(e)
+                    print(f"  ✗ Attempt '{attempt_name}' failed: {error_msg[:100]}")
+
+                    # Clean up on error
+                    if 'asr' in locals():
+                        del asr
+                    gc.collect()
+                    if device_idx >= 0:
+                        torch.cuda.empty_cache()
+
+                    # Continue to next attempt
+                    continue
+
+            if not chunk_processed:
+                print(f"  ⚠ Chunk {chunk_idx + 1} could not be processed, skipping")
+
+            # Move to next chunk
+            if end_sample < len(wav):
+                offset += chunk_samples - overlap_samples
+            else:
+                break
+
+            chunk_idx += 1
+
+        # Clean up model after all chunks
+        del model
+        del processor
+        gc.collect()
+        if device_idx >= 0:
+            torch.cuda.empty_cache()
+
+        # Merge results based on whether we got word or chunk timestamps.
+        # Check what we actually got (might be mixed if some chunks fell back)
+        has_word_timestamps = any(item.get("_is_word", False) for item in all_results)
+
+        if has_word_timestamps:
+            print("[Whisper] Merging word-level timestamps...")
+            final_chunks = self._merge_word_chunks(all_results, overlap_duration)
+        else:
+            print("[Whisper] Merging chunk-level timestamps...")
+            final_chunks = self._merge_overlapping_chunks(all_results, overlap_duration)
+
+        # Clean the results to remove any garbage/hallucinations
+        cleaned_chunks = []
+        for chunk in final_chunks:
+            text = chunk.get("text", "").strip()
+
+            # Filter out common hallucination patterns
+            if not text:
+                continue
+            if len(set(text)) < 3 and len(text) > 10:  # Repetitive characters
+                continue
+            if text.count("$") > len(text) * 0.5:  # Too many special characters
+                continue
+            if text.count("�") > 0:  # U+FFFD replacement chars from decode errors
+                continue
+
+            chunk["text"] = text
+            cleaned_chunks.append(chunk)
+
+        # Build the final result
+        result = {
+            "chunks": cleaned_chunks,
+            "text": " ".join(ch["text"] for ch in cleaned_chunks),
+            "word_timestamps": has_word_timestamps
+        }
+
+        print(f"[Whisper] Transcription complete: {len(cleaned_chunks)} segments, "
+              f"{len(result['text'].split())} words")
+
+        return result
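+
+    # Examples of segment texts the hallucination filters above would drop:
+    #   "aaaaaaaaaaaa"   (fewer than 3 distinct characters over length > 10)
+    #   "$$$$$$$$$$$$"   (more than half the characters are "$")
+    #   anything containing U+FFFD, the "�" replacement character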
+
+    def _merge_overlapping_chunks(self, chunks: List[dict], overlap_duration: float) -> List[dict]:
+        """
+        Intelligently merge chunks that might have overlapping content.
+
+        When we process overlapping audio segments, we might get duplicate
+        transcriptions at the boundaries. This function:
+        1. Detects potential duplicates based on timestamp overlap
+        2. Keeps the best version (usually from the chunk where it's not at the edge)
+        3. Maintains temporal order
+        """
+        if not chunks:
+            return []
+
+        # Sort by start time (guard against missing or None timestamps)
+        chunks.sort(key=lambda x: (x.get("timestamp") or (0, 0))[0] or 0)
+
+        merged = []
+        for chunk in chunks:
+            if not chunk.get("text", "").strip():
+                continue
+
+            timestamp = chunk.get("timestamp", (None, None))
+            if not timestamp or timestamp[0] is None:
+                continue
+
+            # Check if this chunk overlaps significantly with the last merged chunk
+            if merged:
+                last = merged[-1]
+                last_ts = last.get("timestamp", (None, None))
+
+                if last_ts and last_ts[1] and timestamp[0]:
+                    # If timestamps overlap significantly
+                    overlap = last_ts[1] - timestamp[0]
+                    if overlap > 0.5:  # More than 0.5 second overlap
+                        # Compare text similarity to detect duplicates
+                        last_text = last.get("text", "").strip().lower()
+                        curr_text = chunk.get("text", "").strip().lower()
+
+                        # Simple duplicate detection
+                        if last_text == curr_text:
+                            # Skip this duplicate
+                            continue
+
+                        # If texts are very similar (e.g., one is a subset of the other)
+                        if len(last_text) > 10 and len(curr_text) > 10:
+                            if last_text in curr_text or curr_text in last_text:
+                                # Keep the longer version
+                                if len(curr_text) > len(last_text):
+                                    merged[-1] = chunk
+                                continue
+
+            merged.append(chunk)
+
+        return merged
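+
+    # For example, two boundary segments
+    #   {"text": "the quick brown", "timestamp": (41.0, 44.0)}
+    #   {"text": "the quick brown fox", "timestamp": (43.0, 46.0)}
+    # overlap by 1.0s, and the first text is a prefix of the second, so the
+    # merge keeps only the longer segment.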
+
+    def _merge_word_chunks(self, chunks: List[dict], overlap_duration: float) -> List[dict]:
+        """
+        Special merging logic for word-level timestamps.
+
+        Word-level chunks need more careful handling because:
+        1. Words at boundaries might appear in multiple chunks
+        2. Timestamp precision is more important
+        3. We need to maintain word order exactly
+        """
+        if not chunks:
+            return []
+
+        # Sort by start timestamp (guard against missing or None timestamps)
+        chunks.sort(key=lambda x: ((x.get("timestamp") or (0, 0))[0] or 0, x.get("_chunk_idx", 0)))
+
+        merged = []
+        seen_words = set()  # Track (word, approximate_time) to avoid duplicates
+
+        for chunk in chunks:
+            word = chunk.get("text", "").strip()
+            if not word:
+                continue
+
+            timestamp = chunk.get("timestamp", (None, None))
+            if not timestamp or timestamp[0] is None:
+                continue
+
+            # Create a key for duplicate detection.
+            # Round timestamp to the nearest 0.1s for fuzzy matching
+            time_key = round(timestamp[0], 1)
+            word_key = (word.lower(), time_key)
+
+            # Skip if we've seen this word at approximately this time
+            if word_key in seen_words:
+                continue
+
+            seen_words.add(word_key)
+            merged.append(chunk)
+
+        return merged
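+
+    # For example, the word "hello" at 12.31s and a duplicate at 12.34s both map
+    # to the key ("hello", 12.3), so the copy from the overlapping chunk is dropped.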
+
+    def clear_cuda(self):
+        """
+        Clear CUDA cache and free all GPU memory used by this loader.
+
+        This method:
+        1. Deletes the summarizer pipeline if loaded
+        2. Forces garbage collection
+        3. Clears the PyTorch CUDA cache
+
+        Call this method when done processing to free VRAM for other tasks.
+        """
+        freed_items = []
+
+        # Free summarizer if it was loaded
+        if self._summarizer is not None:
+            del self.summarizer  # Uses the deleter, which handles cleanup
+            freed_items.append("summarizer")
+
+        # Force garbage collection
+        gc.collect()
+
+        # Clear CUDA cache if on GPU (device may be a string like "cuda:0" or an int index)
+        device = getattr(self, '_summarizer_device', None)
+        if (isinstance(device, str) and device.startswith('cuda')) or \
+           (isinstance(device, int) and device >= 0):
+            try:
+                import torch  # lazy import, mirroring the transcription paths
+                torch.cuda.empty_cache()
+                freed_items.append("CUDA cache")
+            except Exception as e:
+                print(f"[ParrotBot] Warning: Failed to clear CUDA cache: {e}")
+
+        if freed_items:
+            print(f"[ParrotBot] 🧹 Cleared: {', '.join(freed_items)}")
+        else:
+            print("[ParrotBot] 🧹 No GPU resources to clear")
+
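+    # Typical call-site sketch (the `loader` instance stands in for a concrete
+    # subclass of this abstract base; the name is hypothetical):
+    #
+    #     result = loader.get_whisper_transcript(wav_path)
+    #     ...use result["text"] / result["chunks"]...
+    #     loader.clear_cuda()  # release VRAM when done
+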
+    @abstractmethod
+    async def _load(self, source: str, **kwargs) -> List[Document]:
+        pass
+
+    @abstractmethod
+    async def load_video(self, url: str, video_title: str, transcript: str) -> list:
+        pass