transformers 5.0.0rc0-py3-none-any.whl → 5.0.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/models/lasr/modeling_lasr.py
@@ -0,0 +1,729 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/lasr/modular_lasr.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_lasr.py file directly. One of our CI checks enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable
from dataclasses import dataclass
from typing import Optional, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...integrations import use_kernel_func_from_hub, use_kernelized_func
from ...masking_utils import create_bidirectional_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, CausalLMOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import check_model_inputs, maybe_autocast
from .configuration_lasr import LasrCTCConfig, LasrEncoderConfig


class LasrEncoderSubsampling(nn.Module):
    def __init__(self, config: LasrEncoderConfig):
        super().__init__()
        self.dense_0 = nn.Linear(config.num_mel_bins, config.hidden_size)
        self.conv_0 = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=config.subsampling_conv_kernel_size,
            stride=config.subsampling_conv_stride,
        )
        self.conv_1 = nn.Conv1d(
            config.hidden_size,
            config.subsampling_conv_channels,
            kernel_size=config.subsampling_conv_kernel_size,
            stride=config.subsampling_conv_stride,
        )
        self.dense_1 = nn.Linear(config.subsampling_conv_channels, config.hidden_size)
        self.act_fn = nn.ReLU()

    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.act_fn(self.dense_0(input_features))
        hidden_states = hidden_states.transpose(1, 2)
        hidden_states = self.act_fn(self.conv_0(hidden_states))
        hidden_states = self.act_fn(self.conv_1(hidden_states))
        hidden_states = hidden_states.transpose(1, 2)
        return self.dense_1(hidden_states)

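# Note on the subsampling stem above: dense_0 lifts (batch, time, num_mel_bins)
# to (batch, time, hidden_size); the two strided convolutions run along the
# time axis (hence the transposes), with conv_1 moving the channel count to
# subsampling_conv_channels; dense_1 then projects back to hidden_size, giving
# (batch, time', hidden_size) for a subsampled length time' (computed by
# `_get_subsampling_output_length` further down).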

class LasrEncoderRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: LasrEncoderConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = inv_freq

    @staticmethod
    def compute_default_rope_parameters(
        config: Optional[LasrEncoderConfig] = None,
        device: Optional["torch.device"] = None,
        seq_len: Optional[int] = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

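# How the pieces above fit together (illustrative shapes, not shipped code):
# given q and k of shape (batch, heads, seq, head_dim) and cos/sin of shape
# (batch, seq, head_dim) from LasrEncoderRotaryEmbedding.forward,
# `apply_rotary_pos_emb(q, k, cos, sin)` unsqueezes cos/sin at dim 1 so they
# broadcast over heads, and returns the rotated (q, k) with unchanged shapes.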

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

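# Worked example (illustrative values): with num_attention_heads=8 and
# num_key_value_heads=2, n_rep is 4, so a key tensor of shape
# (batch, 2, seq, head_dim) expands to (batch, 8, seq, head_dim) before the
# matmul in eager_attention_forward below.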

def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

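# Note: the softmax is taken in float32 for numerical stability and cast back
# to the query dtype; the mask is sliced to the key length and added as a
# bias (zeros for visible positions, large negative values for masked ones),
# and attn_output comes back as (batch, seq, heads, head_dim) so the caller
# can flatten the last two dims before the output projection.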

@use_kernelized_func(apply_rotary_pos_emb)
class LasrEncoderAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LasrEncoderConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class LasrEncoderConvolutionModule(nn.Module):
    def __init__(self, config: LasrEncoderConfig, module_config=None):
        """
        Args:
            config (LasrEncoderConfig): Configuration for the model.
            module_config (dict): Configuration for the module (e.g., encoder or decoder).
        """
        super().__init__()
        channels = config.hidden_size
        # kernel_size should be an odd number for 'SAME' padding
        if module_config is None:
            # e.g. using `LasrEncoderConfig` in src/transformers/models/lasr/configuration_lasr.py
            kernel_size = config.conv_kernel_size
            self.activation = ACT2FN[getattr(config, "hidden_act", "silu")]
        else:
            kernel_size = module_config["kernel_size"]
            self.activation = ACT2FN[module_config.get("activation", "silu")]
        self.padding = "same"
        self.pointwise_conv1 = nn.Conv1d(
            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=self.padding,
            groups=channels,
            bias=config.convolution_bias,
        )
        self.norm = nn.BatchNorm1d(config.hidden_size, momentum=config.batch_norm_momentum)
        self.pointwise_conv2 = nn.Conv1d(
            channels, channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
        )

    def forward(self, hidden_states, attention_mask=None):
        """
        Compute convolution module.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
            attention_mask (`torch.Tensor` of shape `(batch, 1, time, time)`): Attention mask.

        Returns:
            `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
        """
        # exchange the temporal dimension and the feature dimension
        hidden_states = hidden_states.transpose(1, 2)

        # GLU mechanism, (batch_size, 2*channel, dim)
        hidden_states = self.pointwise_conv1(hidden_states)
        # (batch_size, channel, dim)
        hidden_states = nn.functional.glu(hidden_states, dim=1)

        # Apply padding mask before convolution
        if attention_mask is not None:
            if attention_mask.dtype == torch.bool:
                all_masked_rows = torch.all(~attention_mask, dim=2)
            else:
                all_masked_rows = torch.all(~(attention_mask == 0.0), dim=2)
            hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)

        # 1D Depthwise Conv
        hidden_states = self.depthwise_conv(hidden_states)
        hidden_states = self.norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.pointwise_conv2(hidden_states)

        return hidden_states.transpose(1, 2)

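# Channel arithmetic in the module above: pointwise_conv1 doubles the channels
# so nn.functional.glu can split them into a value half and a sigmoid gate
# half (back to hidden_size channels); the depthwise conv then mixes along
# time only ('same' padding preserves the length) and pointwise_conv2 mixes
# across channels again.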

class LasrEncoderFeedForward(nn.Module):
    def __init__(self, config: LasrEncoderConfig):
        super().__init__()
        self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.attention_bias)
        self.activation = ACT2FN[config.hidden_act]
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.attention_bias)
        self.activation_dropout = config.activation_dropout

    def forward(self, hidden_states):
        hidden_states = self.activation(self.linear1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.linear2(hidden_states)
        return hidden_states


class LasrEncoderBlock(GradientCheckpointingLayer):
    def __init__(self, config: LasrEncoderConfig, layer_idx: int):
        super().__init__()
        self.gradient_checkpointing = False

        self.feed_forward1 = LasrEncoderFeedForward(config)
        self.self_attn = LasrEncoderAttention(config, layer_idx)
        self.conv = LasrEncoderConvolutionModule(config)
        self.feed_forward2 = LasrEncoderFeedForward(config)

        self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
        self.norm_self_att = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
        self.norm_conv = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
        self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
        self.norm_out = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)

        self.feed_forward_residual_weights = config.feed_forward_residual_weights
        self.conv_residual_weights = config.conv_residual_weights

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
        hidden_states = (
            self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
        )

        normalized_hidden_states = self.norm_self_att(hidden_states)
        attn_output, _ = self.self_attn(
            hidden_states=normalized_hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
        hidden_states = self.conv_residual_weights[0] * hidden_states + self.conv_residual_weights[1] * conv_output

        residual = hidden_states
        hidden_states = self.feed_forward2(self.norm_feed_forward2(hidden_states))
        hidden_states = (
            self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
        )

        hidden_states = self.norm_out(hidden_states)

        return hidden_states

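# Block layout (Conformer-style "macaron" sandwich): feed-forward ->
# self-attention -> convolution -> feed-forward, each branch applied after a
# pre-LayerNorm, with a final norm_out. Where the classic Conformer scales the
# feed-forward branches by a fixed 0.5, here the residual/branch mix comes from
# the config (feed_forward_residual_weights and conv_residual_weights); the
# attention branch uses a plain residual.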

@auto_docstring
class LasrPreTrainedModel(PreTrainedModel):
    config: LasrCTCConfig
    base_model_prefix = "model"
    main_input_name = "input_features"
    input_modalities = "audio"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LasrEncoderBlock"]
    _supports_flat_attention_mask = True
    _supports_sdpa = True
    _supports_flex_attn = True

    # TODO: @eustlb, add support when flash attention supports custom attention bias
    _supports_flash_attn = False

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": LasrEncoderBlock,
        "attentions": LasrEncoderAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)

    def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
        encoder_config = self.config.encoder_config if isinstance(self.config, LasrCTCConfig) else self.config
        kernel_size = encoder_config.subsampling_conv_kernel_size
        stride = encoder_config.subsampling_conv_stride

        num_layers = 2
        for _ in range(num_layers):
            input_lengths = (input_lengths - kernel_size) // stride + 1

        return input_lengths

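    # Worked example (illustrative values, not the shipped config): with
    # subsampling_conv_kernel_size=3 and subsampling_conv_stride=2, an input of
    # 1000 frames gives (1000 - 3) // 2 + 1 = 499 after the first conv and
    # (499 - 3) // 2 + 1 = 249 after the second, i.e. roughly a 4x reduction.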

    def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length: Optional[int] = None):
        """
        Convert the input attention mask to its subsampled form. `target_length` sets the desired output length,
        useful when the attention mask length differs from `sum(-1).max()` (i.e., when the longest sequence in the
        batch is padded).
        """
        output_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
        # Use target_length if provided, otherwise use max length in batch
        max_length = target_length if target_length is not None else output_lengths.max()
        attention_mask = torch.arange(max_length, device=attention_mask.device) < output_lengths[:, None]
        return attention_mask

463
+ @auto_docstring(
+     custom_intro="""
+     The LasrEncoder model, based on the [Conformer architecture](https://arxiv.org/abs/2005.08100).
+     """
+ )
+ class LasrEncoder(LasrPreTrainedModel):
+     config: LasrEncoderConfig
+     base_model_prefix = "encoder"
+
+     def __init__(self, config: LasrEncoderConfig):
+         super().__init__(config)
+         self.gradient_checkpointing = False
+
+         self.dropout = config.dropout
+         self.dropout_positions = config.dropout_positions
+         self.layerdrop = config.layerdrop
+
+         self.subsampler = LasrEncoderSubsampling(config)
+         self.rotary_emb = LasrEncoderRotaryEmbedding(config)
+         self.layers = nn.ModuleList(
+             [LasrEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+         )
+         self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
+
+         self.post_init()
+
+     @auto_docstring
+     @check_model_inputs()
+     @can_return_tuple
+     def forward(
+         self,
+         input_features: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> BaseModelOutput:
+         r"""
+         Example:
+
+         ```python
+         >>> from transformers import AutoProcessor, LasrEncoder
+         >>> from datasets import load_dataset, Audio
+
+         >>> model_id = TODO
+         >>> processor = AutoProcessor.from_pretrained(model_id)
+         >>> encoder = LasrEncoder.from_pretrained(model_id)
+
+         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+         >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+
+         >>> inputs = processor(ds[0]["audio"]["array"])
+         >>> encoder_outputs = encoder(**inputs)
+
+         >>> print(encoder_outputs.last_hidden_state.shape)
+         ```
+         """
+
+         hidden_states = self.subsampler(input_features)
+         cos, sin = self.rotary_emb(
+             hidden_states, torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
+         )
+
+         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+         cos = nn.functional.dropout(cos, p=self.dropout_positions, training=self.training)
+         sin = nn.functional.dropout(sin, p=self.dropout_positions, training=self.training)
+
+         if attention_mask is not None:
+             attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
+
+         attention_mask = create_bidirectional_mask(
+             config=self.config,
+             input_embeds=hidden_states,
+             attention_mask=attention_mask,
+         )
+
+         for encoder_layer in self.layers:
+             # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+             to_drop = False
+             if self.training:
+                 dropout_probability = torch.rand([])
+                 if dropout_probability < self.layerdrop:  # skip the layer
+                     to_drop = True
+
+             if not to_drop:
+                 hidden_states = encoder_layer(
+                     hidden_states,
+                     attention_mask=attention_mask,
+                     position_embeddings=(cos, sin),
+                     **kwargs,
+                 )
+
+         hidden_states = self.out_norm(hidden_states)
+
+         return BaseModelOutput(last_hidden_state=hidden_states)
+
+
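The LayerDrop loop in `forward` skips each encoder layer independently with probability `config.layerdrop`, and only during training; at inference every layer always runs. A quick sketch of the expected depth (the layer count and drop rate below are illustrative assumptions):

```python
import torch

num_layers, layerdrop = 16, 0.1

# Each layer survives independently with probability 1 - layerdrop,
# so the expected number of executed layers is num_layers * (1 - layerdrop).
expected = num_layers * (1 - layerdrop)  # 14.4

# Empirical check over 10k simulated training-time forward passes.
executed = (torch.rand(10_000, num_layers) >= layerdrop).sum(-1).float()
print(expected, executed.mean().item())  # ~14.4
```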
558
+ @dataclass
+ class LasrGenerateOutput(ModelOutput):
+     """
+     Outputs of Lasr models.
+
+     Args:
+         sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+             The generated sequences. The second dimension (`sequence_length`) is either equal to `max_length` or
+             shorter if all batches finished early due to the `eos_token_id`.
+         logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True`):
+             Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before
+             SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one
+             element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
+         attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
+             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+             `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
+         hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
+             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+             `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
+     """
+
+     sequences: torch.LongTensor
+     logits: Optional[tuple[torch.FloatTensor]] = None
+     attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
+     hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
+
+
585
+ @auto_docstring(
+     custom_intro="""
+     Lasr Encoder with a Connectionist Temporal Classification (CTC) head.
+     """
+ )
+ class LasrForCTC(LasrPreTrainedModel):
+     config: LasrCTCConfig
+
+     def __init__(self, config: LasrCTCConfig):
+         super().__init__(config)
+         self.encoder = LasrEncoder(config.encoder_config)
+         # Conv rather than linear to stay consistent with the NeMo decoding layer
+         self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)
+
+         self.post_init()
+
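As the comment notes, a `Conv1d` with `kernel_size=1` is mathematically a per-frame linear projection; the convolutional form is kept only to match NeMo's checkpoint layout. A quick equivalence sketch (the layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

hidden_size, vocab_size = 8, 11
conv = nn.Conv1d(hidden_size, vocab_size, kernel_size=1)

# Copy the conv weights into an equivalent Linear layer.
linear = nn.Linear(hidden_size, vocab_size)
linear.weight.data = conv.weight.data.squeeze(-1)  # (vocab, hidden, 1) -> (vocab, hidden)
linear.bias.data = conv.bias.data

x = torch.randn(2, 5, hidden_size)  # (batch, time, hidden)
out_conv = conv(x.transpose(1, 2)).transpose(1, 2)  # Conv1d wants (batch, hidden, time)
print(torch.allclose(out_conv, linear(x), atol=1e-5))  # True
```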
601
+     @auto_docstring
+     @can_return_tuple
+     def forward(
+         self,
+         input_features: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> CausalLMOutput:
+         r"""
+         Example:
+
+         ```python
+         >>> from transformers import AutoProcessor, LasrForCTC
+         >>> from datasets import load_dataset, Audio
+
+         >>> model_id = "nvidia/lasr-ctc-1.1b"
+         >>> processor = AutoProcessor.from_pretrained(model_id)
+         >>> model = LasrForCTC.from_pretrained(model_id)
+
+         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+         >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+
+         >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
+         >>> outputs = model(**inputs)
+
+         >>> print(outputs.loss)
+         ```"""
+
+         encoder_outputs = self.encoder(
+             input_features=input_features,
+             attention_mask=attention_mask,
+             **kwargs,
+         )
+
+         hidden_states = encoder_outputs.last_hidden_state
+         logits = self.ctc_head(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         loss = None
+         if labels is not None:
+             # retrieve loss input_lengths from attention_mask
+             attention_mask = (
+                 attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long)
+             )
+             input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
+
+             # assuming that padded label positions are filled with `pad_token_id`
+             # and should not contribute to the loss
+             labels_mask = labels != self.config.pad_token_id
+             target_lengths = labels_mask.sum(-1)
+             flattened_targets = labels.masked_select(labels_mask)
+
+             # ctc_loss doesn't support fp16
+             log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+             with torch.backends.cudnn.flags(enabled=False):
+                 loss = nn.functional.ctc_loss(
+                     log_probs,
+                     flattened_targets,
+                     input_lengths,
+                     target_lengths,
+                     blank=self.config.pad_token_id,
+                     reduction=self.config.ctc_loss_reduction,
+                     zero_infinity=self.config.ctc_zero_infinity,
+                 )
+
+         return CausalLMOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
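For reference, `nn.functional.ctc_loss` expects log-probabilities of shape `(time, batch, vocab)` (hence the `.transpose(0, 1)` above), a flattened 1-D target tensor, and per-sample input/target lengths. A shape-only sketch with made-up sizes, using blank id 0 to mirror the `pad_token_id`-as-blank convention used here:

```python
import torch
import torch.nn as nn

batch, time, vocab = 2, 50, 32
logits = torch.randn(batch, time, vocab)
# log_softmax in fp32, then (batch, time, vocab) -> (time, batch, vocab)
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

input_lengths = torch.tensor([50, 42])   # valid encoder frames per sample
target_lengths = torch.tensor([12, 9])   # valid label tokens per sample
targets = torch.randint(1, vocab, (batch, 12))  # 0 is reserved for the blank
flattened_targets = torch.cat([targets[0, :12], targets[1, :9]])

loss = nn.functional.ctc_loss(log_probs, flattened_targets, input_lengths, target_lengths, blank=0, zero_infinity=True)
print(loss)
```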
674
+     @torch.no_grad()
+     def generate(
+         self,
+         input_features: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         return_dict_in_generate: bool = False,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> Union[LasrGenerateOutput, torch.LongTensor]:
+         r"""
+         Example:
+
+         ```python
+         >>> from transformers import AutoProcessor, LasrForCTC
+         >>> from datasets import load_dataset, Audio
+
+         >>> model_id = TODO
+         >>> processor = AutoProcessor.from_pretrained(model_id)
+         >>> model = LasrForCTC.from_pretrained(model_id)
+
+         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+         >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+
+         >>> inputs = processor(ds[0]["audio"]["array"])
+         >>> predicted_ids = model.generate(**inputs)
+         >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+
+         >>> print(transcription)
+         ```
+         """
+         kwargs["return_dict"] = True
+         outputs: CausalLMOutput = self.forward(
+             input_features=input_features,
+             attention_mask=attention_mask,
+             **kwargs,
+         )
+
+         # greedy decoding
+         sequences = outputs.logits.argmax(dim=-1)
+
+         # mask out padded tokens
+         if attention_mask is not None:
+             attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequences.shape[1])
+             sequences[~attention_mask] = self.config.pad_token_id
+
+         if return_dict_in_generate:
+             return LasrGenerateOutput(
+                 sequences=sequences,
+                 logits=outputs.logits,
+                 attentions=outputs.attentions,
+                 hidden_states=outputs.hidden_states,
+             )
+
+         return sequences
+
+
+ __all__ = ["LasrForCTC", "LasrEncoder", "LasrPreTrainedModel"]
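Note that `generate` returns frame-level argmax ids; collapsing repeated ids and dropping blanks (standard greedy CTC decoding) is left to the processor's `batch_decode`. The collapse rule itself, as a minimal sketch (treating id 0 as the blank is an assumption for illustration):

```python
import torch

def ctc_collapse(ids: torch.Tensor, blank_id: int = 0) -> list[int]:
    # Merge consecutive duplicates, then drop blank tokens.
    out, prev = [], None
    for i in ids.tolist():
        if i != prev and i != blank_id:
            out.append(i)
        prev = i
    return out

print(ctc_collapse(torch.tensor([0, 3, 3, 0, 0, 5, 5, 5, 2])))  # [3, 5, 2]
```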