transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,730 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/lasr/modular_lasr.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_lasr.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from collections.abc import Callable
23
+ from dataclasses import dataclass
24
+ from typing import Optional, Union
25
+
26
+ import torch
27
+ from torch import nn
28
+
29
+ from ...activations import ACT2FN
30
+ from ...integrations import use_kernel_func_from_hub, use_kernelized_func
31
+ from ...masking_utils import create_bidirectional_mask
32
+ from ...modeling_layers import GradientCheckpointingLayer
33
+ from ...modeling_outputs import BaseModelOutput, CausalLMOutput
34
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
35
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
36
+ from ...processing_utils import Unpack
37
+ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
38
+ from ...utils.generic import check_model_inputs, maybe_autocast
39
+ from .configuration_lasr import LasrCTCConfig, LasrEncoderConfig
40
+
41
+
42
class LasrEncoderSubsampling(nn.Module):
    """
    Front-end that projects mel features to `hidden_size` and shortens the time axis with two strided 1D convolutions.

    Pipeline: Linear(num_mel_bins -> hidden_size) + ReLU, Conv1d(hidden_size -> hidden_size) + ReLU,
    Conv1d(hidden_size -> subsampling_conv_channels) + ReLU, then Linear(subsampling_conv_channels -> hidden_size).
    Both convolutions share `subsampling_conv_kernel_size` / `subsampling_conv_stride` and use the default (zero)
    padding.
    """

    def __init__(self, config: LasrEncoderConfig):
        super().__init__()
        hidden = config.hidden_size
        conv_kwargs = {
            "kernel_size": config.subsampling_conv_kernel_size,
            "stride": config.subsampling_conv_stride,
        }
        # Keep this construction order: it fixes the order in which parameters are registered
        # (state-dict key order) and the order in which they consume the RNG during default init.
        self.dense_0 = nn.Linear(config.num_mel_bins, hidden)
        self.conv_0 = nn.Conv1d(hidden, hidden, **conv_kwargs)
        self.conv_1 = nn.Conv1d(hidden, config.subsampling_conv_channels, **conv_kwargs)
        self.dense_1 = nn.Linear(config.subsampling_conv_channels, hidden)
        self.act_fn = nn.ReLU()

    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_features: assumed `(batch, time, num_mel_bins)` — the last axis feeds `dense_0`, and axes 1/2
                are swapped so the convolutions slide over time. TODO confirm against the feature extractor.

        Returns:
            Tensor of shape `(batch, subsampled_time, hidden_size)`.
        """
        features = self.act_fn(self.dense_0(input_features))
        # Conv1d expects (batch, channels, length): move the feature axis to position 1.
        features = features.transpose(1, 2)
        for conv in (self.conv_0, self.conv_1):
            features = self.act_fn(conv(features))
        # Back to (batch, length, channels) before the final projection.
        return self.dense_1(features.transpose(1, 2))
68
+
69
+
70
class LasrEncoderRotaryEmbedding(nn.Module):
    """
    Produces the rotary position embedding (RoPE) `cos`/`sin` tables consumed by the encoder attention layers.

    The inverse frequencies come either from `compute_default_rope_parameters` (when
    `config.rope_parameters["rope_type"] == "default"`) or from the matching entry in `ROPE_INIT_FUNCTIONS`.
    """

    inv_freq: torch.Tensor  # declared for linters; actually registered as a buffer in __init__

    def __init__(self, config: LasrEncoderConfig, device=None):
        super().__init__()
        self.config = config
        # Both start at the configured maximum; the cached value is presumably adjusted by
        # `dynamic_rope_update` for dynamic RoPE variants (verify in modeling_rope_utils).
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.rope_type = config.rope_parameters["rope_type"]
        if self.rope_type == "default":
            rope_init_fn: Callable = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent: recomputed from the config, never written to checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Pristine copy kept alongside the working buffer (presumably so dynamic RoPE can restore it).
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: Optional[LasrEncoderConfig] = None,
        device: Optional["torch.device"] = None,
        seq_len: Optional[int] = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Standard RoPE inverse frequencies: `1 / theta ** (2i / dim)` for `i = 0 .. dim/2 - 1`.

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                Model configuration; reads `rope_parameters["rope_theta"]` and either `head_dim` or
                `hidden_size // num_attention_heads`.
            device (`torch.device`):
                Device on which the frequencies are created.
            seq_len (`int`, *optional*):
                Ignored by this RoPE type.

        Returns:
            `(inv_freq, attention_factor)` where `attention_factor` is always `1.0` for default RoPE.
        """
        theta = config.rope_parameters["rope_theta"]
        # A missing *or* None head_dim falls back to hidden_size // num_attention_heads.
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        even_indices = torch.arange(0, head_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
        inv_freq = 1.0 / (theta ** (even_indices / head_dim))
        return inv_freq, 1.0

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """
        Build `(cos, sin)` for the given positions.

        `position_ids` is indexed as `(batch, seq_len)`; the returned tables have shape
        `(batch, seq_len, 2 * len(inv_freq))` and are cast to `x.dtype`. `x` only supplies device and dtype.
        """
        batch_size = position_ids.shape[0]
        freq_column = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        position_row = position_ids[:, None, :].float()

        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        # Autocast is disabled so the angle computation always runs in float32.
        with maybe_autocast(device_type=device_type, enabled=False):
            angles = torch.matmul(freq_column.float(), position_row.float()).transpose(1, 2)
            table = torch.cat([angles, angles], dim=-1)
            cos = table.cos() * self.attention_scaling
            sin = table.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
133
+
134
+
135
def rotate_half(x):
    """
    Split the last dimension of `x` into halves `(a, b)` and return `concat(-b, a)`.

    For an odd width the first half is the shorter one (`width // 2` elements).
    """
    split_at = x.shape[-1] // 2
    first, second = x[..., :split_at], x[..., split_at:]
    return torch.cat((second.neg(), first), dim=-1)
140
+
141
+
142
@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with precomputed rotary position embeddings (RoPE).

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine table, e.g. `[batch_size, seq_len, head_dim]` from the rotary module.
        sin (`torch.Tensor`): Sine table, same shape as `cos`.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and ignored; kept only so existing call sites keep working.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a singleton dimension is inserted into `cos` and `sin` so they broadcast against `q`/`k`.
            Use 1 when `q`/`k` are `[batch_size, heads, seq_len, head_dim]` and 2 when they are
            `[batch_size, seq_len, heads, head_dim]`.

    Returns:
        `tuple(torch.Tensor)`: the rotated `(query, key)` pair.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate(tensor):
        # x * cos + rotate_half(x) * sin, applied identically to queries and keys.
        return (tensor * cos) + (rotate_half(tensor) * sin)

    return _rotate(q), _rotate(k)
168
+
169
+
170
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Tile key/value heads so grouped-query attention can reuse plain multi-head math.

    Equivalent to `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`: each of the
    `num_key_value_heads` heads is repeated `n_rep` times consecutively, mapping
    `(batch, num_key_value_heads, seq_len, head_dim)` to `(batch, num_key_value_heads * n_rep, seq_len, head_dim)`.
    When `n_rep == 1` the very same tensor object is returned.
    """
    # Unpack before the early return so that a non-4D input raises even when n_rep == 1.
    bsz, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # expand() adds the repeat axis as a view; reshape() then materializes the interleaved heads.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, head_dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, head_dim)
180
+
181
+
182
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Reference (non-fused) scaled dot-product attention.

    Key/value heads are repeated `module.num_key_value_groups` times to match the query heads. If an additive
    `attention_mask` is given, its last axis is cropped to the key length and added to the scores. Softmax is
    computed in float32 and cast back to `query.dtype`. Dropout is applied only while `module.training` is True.

    Returns:
        `(attn_output, attn_weights)`. `attn_output` has axes 1 and 2 swapped relative to the matmul result
        (heads <-> sequence) and is made contiguous; `attn_weights` are the post-dropout probabilities.
    """
    groups = module.num_key_value_groups
    key_states = repeat_kv(key, groups)
    value_states = repeat_kv(value, groups)

    scores = (query @ key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Only the first `kv_len` mask columns line up with the keys actually present.
        kv_len = key_states.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = (probs @ value_states).transpose(1, 2).contiguous()
    return context, probs
206
+
207
+
208
+ @use_kernelized_func(apply_rotary_pos_emb)
209
+ class LasrEncoderAttention(nn.Module):
210
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
211
+
212
+ def __init__(self, config: LasrEncoderConfig, layer_idx: int):
213
+ super().__init__()
214
+ self.config = config
215
+ self.layer_idx = layer_idx
216
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
217
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
218
+ self.scaling = self.head_dim**-0.5
219
+ self.attention_dropout = config.attention_dropout
220
+ self.is_causal = False
221
+
222
+ self.q_proj = nn.Linear(
223
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
224
+ )
225
+ self.k_proj = nn.Linear(
226
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
227
+ )
228
+ self.v_proj = nn.Linear(
229
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
230
+ )
231
+ self.o_proj = nn.Linear(
232
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
233
+ )
234
+
235
+ def forward(
236
+ self,
237
+ hidden_states: torch.Tensor,
238
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
239
+ attention_mask: Optional[torch.Tensor] = None,
240
+ **kwargs: Unpack[TransformersKwargs],
241
+ ) -> tuple[torch.Tensor, torch.Tensor]:
242
+ input_shape = hidden_states.shape[:-1]
243
+ hidden_shape = (*input_shape, -1, self.head_dim)
244
+
245
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
246
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
247
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
248
+
249
+ cos, sin = position_embeddings
250
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
251
+
252
+ attention_interface: Callable = eager_attention_forward
253
+ if self.config._attn_implementation != "eager":
254
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
255
+
256
+ attn_output, attn_weights = attention_interface(
257
+ self,
258
+ query_states,
259
+ key_states,
260
+ value_states,
261
+ attention_mask,
262
+ dropout=0.0 if not self.training else self.attention_dropout,
263
+ scaling=self.scaling,
264
+ **kwargs,
265
+ )
266
+
267
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
268
+ attn_output = self.o_proj(attn_output)
269
+ return attn_output, attn_weights
270
+
271
+
272
+ class LasrEncoderConvolutionModule(nn.Module):
273
+ def __init__(self, config: LasrEncoderConfig, module_config=None):
274
+ """
275
+ Args:
276
+ config (LasrEncoderConfig): Configuration for the model.
277
+ module_config (dict): Configuration for the module (e.g., encoder or decoder).
278
+ """
279
+ super().__init__()
280
+ channels = config.hidden_size
281
+ # kernel_size should be an odd number for 'SAME' padding
282
+ if module_config is None:
283
+ # e.g. using `LasrEncoderEncoderConfig` in src/transformers/models/lasr_encoder/configuration_lasr_encoder.py
284
+ kernel_size = config.conv_kernel_size
285
+ self.activation = ACT2FN[getattr(config, "hidden_act", "silu")]
286
+ else:
287
+ kernel_size = module_config["kernel_size"]
288
+ self.activation = ACT2FN[module_config.get("activation", "silu")]
289
+ self.padding = "same"
290
+ self.pointwise_conv1 = nn.Conv1d(
291
+ channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
292
+ )
293
+ self.depthwise_conv = nn.Conv1d(
294
+ channels,
295
+ channels,
296
+ kernel_size,
297
+ stride=1,
298
+ padding=self.padding,
299
+ groups=channels,
300
+ bias=config.convolution_bias,
301
+ )
302
+ self.norm = nn.BatchNorm1d(config.hidden_size, momentum=config.batch_norm_momentum)
303
+ self.pointwise_conv2 = nn.Conv1d(
304
+ channels, channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
305
+ )
306
+
307
+ def forward(self, hidden_states, attention_mask=None):
308
+ """
309
+ Compute convolution module.
310
+
311
+ Args:
312
+ hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
313
+ attention_mask (`torch.Tensor` of shape `(batch, 1, time, time)`): Attention mask.
314
+
315
+ Returns:
316
+ `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
317
+
318
+ """
319
+ # exchange the temporal dimension and the feature dimension
320
+ hidden_states = hidden_states.transpose(1, 2)
321
+
322
+ # GLU mechanism, (batch_size, 2*channel, dim)
323
+ hidden_states = self.pointwise_conv1(hidden_states)
324
+ # (batch_size, channel, dim)
325
+ hidden_states = nn.functional.glu(hidden_states, dim=1)
326
+
327
+ # Apply padding mask before convolution
328
+ if attention_mask is not None:
329
+ if attention_mask.dtype == torch.bool:
330
+ all_masked_rows = torch.all(~attention_mask, dim=2)
331
+ else:
332
+ all_masked_rows = torch.all(~(attention_mask == 0.0), dim=2)
333
+ hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
334
+
335
+ # 1D Depthwise Conv
336
+ hidden_states = self.depthwise_conv(hidden_states)
337
+ hidden_states = self.norm(hidden_states)
338
+ hidden_states = self.activation(hidden_states)
339
+ hidden_states = self.pointwise_conv2(hidden_states)
340
+
341
+ return hidden_states.transpose(1, 2)
342
+
343
+
344
+ class LasrEncoderFeedForward(nn.Module):
345
+ def __init__(self, config: LasrEncoderConfig):
346
+ super().__init__()
347
+ self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.attention_bias)
348
+ self.activation = ACT2FN[config.hidden_act]
349
+ self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.attention_bias)
350
+ self.activation_dropout = config.activation_dropout
351
+
352
+ def forward(self, hidden_states):
353
+ hidden_states = self.activation(self.linear1(hidden_states))
354
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
355
+ hidden_states = self.linear2(hidden_states)
356
+ return hidden_states
357
+
358
+
359
+ class LasrEncoderBlock(GradientCheckpointingLayer):
360
+ def __init__(self, config: LasrEncoderConfig, layer_idx: int):
361
+ super().__init__()
362
+ self.gradient_checkpointing = False
363
+
364
+ self.feed_forward1 = LasrEncoderFeedForward(config)
365
+ self.self_attn = LasrEncoderAttention(config, layer_idx)
366
+ self.conv = LasrEncoderConvolutionModule(config)
367
+ self.feed_forward2 = LasrEncoderFeedForward(config)
368
+
369
+ self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
370
+ self.norm_self_att = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
371
+ self.norm_conv = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
372
+ self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
373
+ self.norm_out = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
374
+
375
+ self.feed_forward_residual_weights = config.feed_forward_residual_weights
376
+ self.conv_residual_weights = config.conv_residual_weights
377
+
378
+ def forward(
379
+ self,
380
+ hidden_states: torch.Tensor,
381
+ attention_mask: Optional[torch.Tensor] = None,
382
+ position_embeddings: Optional[torch.Tensor] = None,
383
+ **kwargs: Unpack[TransformersKwargs],
384
+ ) -> torch.Tensor:
385
+ residual = hidden_states
386
+ hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
387
+ hidden_states = (
388
+ self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
389
+ )
390
+
391
+ normalized_hidden_states = self.norm_self_att(hidden_states)
392
+ attn_output, _ = self.self_attn(
393
+ hidden_states=normalized_hidden_states,
394
+ attention_mask=attention_mask,
395
+ position_embeddings=position_embeddings,
396
+ **kwargs,
397
+ )
398
+ hidden_states = hidden_states + attn_output
399
+
400
+ conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
401
+ hidden_states = self.conv_residual_weights[0] * hidden_states + self.conv_residual_weights[1] * conv_output
402
+
403
+ residual = hidden_states
404
+ hidden_states = self.feed_forward2(self.norm_feed_forward2(hidden_states))
405
+ hidden_states = (
406
+ self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
407
+ )
408
+
409
+ hidden_states = self.norm_out(hidden_states)
410
+
411
+ return hidden_states
412
+
413
+
414
+ @auto_docstring
415
+ class LasrPreTrainedModel(PreTrainedModel):
416
+ config: LasrCTCConfig
417
+ base_model_prefix = "model"
418
+ main_input_name = "input_features"
419
+ input_modalities = "audio"
420
+ supports_gradient_checkpointing = True
421
+ _no_split_modules = ["LasrEncoderBlock"]
422
+ _supports_flat_attention_mask = True
423
+ _supports_sdpa = True
424
+ # padding is incompatible with flex attention as the resulting mask cannot be used to apply padding
425
+ _supports_flex_attn = False
426
+
427
+ # TODO: @eustlb, add support when flash attention supports custom attention bias
428
+ _supports_flash_attn = False
429
+
430
+ _can_compile_fullgraph = True
431
+ _supports_attention_backend = True
432
+ _can_record_outputs = {
433
+ "hidden_states": LasrEncoderBlock,
434
+ "attentions": LasrEncoderAttention,
435
+ }
436
+
437
+ @torch.no_grad()
438
+ def _init_weights(self, module):
439
+ super()._init_weights(module)
440
+
441
+ def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
442
+ encoder_config = self.config.encoder_config if isinstance(self.config, LasrCTCConfig) else self.config
443
+ kernel_size = encoder_config.subsampling_conv_kernel_size
444
+ stride = encoder_config.subsampling_conv_stride
445
+
446
+ num_layers = 2
447
+ for _ in range(num_layers):
448
+ input_lengths = (input_lengths - kernel_size) // stride + 1
449
+
450
+ return input_lengths
451
+
452
+ def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length: Optional[int] = None):
453
+ """
454
+ Convert the input attention mask to its subsampled form. `target_length` sets the desired output length, useful
455
+ when the attention mask length differs from `sum(-1).max()` (i.e., when the longest sequence in the batch is padded)
456
+ """
457
+ output_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
458
+ # Use target_length if provided, otherwise use max length in batch
459
+ max_length = target_length if target_length is not None else output_lengths.max()
460
+ attention_mask = torch.arange(max_length, device=attention_mask.device) < output_lengths[:, None]
461
+ return attention_mask
462
+
463
+
464
+ @auto_docstring(
465
+ custom_intro="""
466
+ The LasrEncoder model, based on the Conformer architecture](https://arxiv.org/abs/2005.08100).
467
+ """
468
+ )
469
+ class LasrEncoder(LasrPreTrainedModel):
470
+ config: LasrEncoderConfig
471
+ base_model_prefix = "encoder"
472
+
473
+ def __init__(self, config: LasrEncoderConfig):
474
+ super().__init__(config)
475
+ self.gradient_checkpointing = False
476
+
477
+ self.dropout = config.dropout
478
+ self.dropout_positions = config.dropout_positions
479
+ self.layerdrop = config.layerdrop
480
+
481
+ self.subsampler = LasrEncoderSubsampling(config)
482
+ self.rotary_emb = LasrEncoderRotaryEmbedding(config)
483
+ self.layers = nn.ModuleList(
484
+ [LasrEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
485
+ )
486
+ self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
487
+
488
+ self.post_init()
489
+
490
+ @auto_docstring
491
+ @check_model_inputs()
492
+ @can_return_tuple
493
+ def forward(
494
+ self,
495
+ input_features: torch.Tensor,
496
+ attention_mask: Optional[torch.Tensor] = None,
497
+ **kwargs: Unpack[TransformersKwargs],
498
+ ) -> BaseModelOutput:
499
+ r"""
500
+ Example:
501
+
502
+ ```python
503
+ >>> from transformers import AutoProcessor, LasrEncoder
504
+ >>> from datasets import load_dataset, Audio
505
+
506
+ >>> model_id = TODO
507
+ >>> processor = AutoProcessor.from_pretrained(model_id)
508
+ >>> encoder = ParakeetEncoder.from_pretrained(model_id)
509
+
510
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
511
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
512
+
513
+ >>> inputs = processor(ds[0]["audio"]["array"])
514
+ >>> encoder_outputs = encoder(**inputs)
515
+
516
+ >>> print(encoder_outputs.last_hidden_state.shape)
517
+ ```
518
+ """
519
+
520
+ hidden_states = self.subsampler(input_features)
521
+ cos, sin = self.rotary_emb(
522
+ hidden_states, torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
523
+ )
524
+
525
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
526
+ cos = nn.functional.dropout(cos, p=self.dropout_positions, training=self.training)
527
+ sin = nn.functional.dropout(sin, p=self.dropout_positions, training=self.training)
528
+
529
+ if attention_mask is not None:
530
+ attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
531
+
532
+ attention_mask = create_bidirectional_mask(
533
+ config=self.config,
534
+ input_embeds=hidden_states,
535
+ attention_mask=attention_mask,
536
+ )
537
+
538
+ for encoder_layer in self.layers:
539
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
540
+ to_drop = False
541
+ if self.training:
542
+ dropout_probability = torch.rand([])
543
+ if dropout_probability < self.layerdrop: # skip the layer
544
+ to_drop = True
545
+
546
+ if not to_drop:
547
+ hidden_states = encoder_layer(
548
+ hidden_states,
549
+ attention_mask=attention_mask,
550
+ position_embeddings=(cos, sin),
551
+ **kwargs,
552
+ )
553
+
554
+ hidden_states = self.out_norm(hidden_states)
555
+
556
+ return BaseModelOutput(last_hidden_state=hidden_states)
557
+
558
+
559
+ @dataclass
560
+ class LasrGenerateOutput(ModelOutput):
561
+ """
562
+ Outputs of Lasr models.
563
+
564
+ Args:
565
+ sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
566
+ The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
567
+ if all batches finished early due to the `eos_token_id`.
568
+ logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
569
+ Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
570
+ at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
571
+ each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
572
+ attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
573
+ Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
574
+ `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
575
+ hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
576
+ Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
577
+ `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
578
+ """
579
+
580
+ sequences: torch.LongTensor
581
+ logits: Optional[tuple[torch.FloatTensor]] = None
582
+ attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
583
+ hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
584
+
585
+
586
+ @auto_docstring(
587
+ custom_intro="""
588
+ Lasr Encoder with a Connectionist Temporal Classification (CTC) head.
589
+ """
590
+ )
591
+ class LasrForCTC(LasrPreTrainedModel):
592
+ config: LasrCTCConfig
593
+
594
+ def __init__(self, config: LasrCTCConfig):
595
+ super().__init__(config)
596
+ self.encoder = LasrEncoder(config.encoder_config)
597
+ # Conv rather than linear to be consistent with NeMO decoding layer
598
+ self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)
599
+
600
+ self.post_init()
601
+
602
+ @auto_docstring
603
+ @can_return_tuple
604
+ def forward(
605
+ self,
606
+ input_features: torch.Tensor,
607
+ attention_mask: Optional[torch.Tensor] = None,
608
+ labels: Optional[torch.Tensor] = None,
609
+ **kwargs: Unpack[TransformersKwargs],
610
+ ) -> CausalLMOutput:
611
+ r"""
612
+ Example:
613
+
614
+ ```python
615
+ >>> from transformers import AutoProcessor, LasrForCTC
616
+ >>> from datasets import load_dataset, Audio
617
+
618
+ >>> model_id = "nvidia/lasr-ctc-1.1b"
619
+ >>> processor = AutoProcessor.from_pretrained(model_id)
620
+ >>> model = LasrForCTC.from_pretrained(model_id)
621
+
622
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
623
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
624
+
625
+ >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
626
+ >>> outputs = model(**inputs)
627
+
628
+ >>> print(outputs.loss)
629
+ ```"""
630
+
631
+ encoder_outputs = self.encoder(
632
+ input_features=input_features,
633
+ attention_mask=attention_mask,
634
+ **kwargs,
635
+ )
636
+
637
+ hidden_states = encoder_outputs.last_hidden_state
638
+ logits = self.ctc_head(hidden_states.transpose(1, 2)).transpose(1, 2)
639
+
640
+ loss = None
641
+ if labels is not None:
642
+ # retrieve loss input_lengths from attention_mask
643
+ attention_mask = (
644
+ attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long)
645
+ )
646
+ input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
647
+
648
+ # assuming that padded tokens are filled with -100
649
+ # when not being attended to
650
+ labels_mask = labels != self.config.pad_token_id
651
+ target_lengths = labels_mask.sum(-1)
652
+ flattened_targets = labels.masked_select(labels_mask)
653
+
654
+ # ctc_loss doesn't support fp16
655
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
656
+
657
+ with torch.backends.cudnn.flags(enabled=False):
658
+ loss = nn.functional.ctc_loss(
659
+ log_probs,
660
+ flattened_targets,
661
+ input_lengths,
662
+ target_lengths,
663
+ blank=self.config.pad_token_id,
664
+ reduction=self.config.ctc_loss_reduction,
665
+ zero_infinity=self.config.ctc_zero_infinity,
666
+ )
667
+
668
+ return CausalLMOutput(
669
+ loss=loss,
670
+ logits=logits,
671
+ hidden_states=encoder_outputs.hidden_states,
672
+ attentions=encoder_outputs.attentions,
673
+ )
674
+
675
+ @torch.no_grad()
676
+ def generate(
677
+ self,
678
+ input_features: torch.Tensor,
679
+ attention_mask: Optional[torch.Tensor] = None,
680
+ return_dict_in_generate: bool = False,
681
+ **kwargs: Unpack[TransformersKwargs],
682
+ ) -> Union[LasrGenerateOutput, torch.LongTensor]:
683
+ r"""
684
+ Example:
685
+
686
+ ```python
687
+ >>> from transformers import AutoProcessor, LasrForCTC
688
+ >>> from datasets import load_dataset, Audio
689
+
690
+ >>> model_id = TODO
691
+ >>> processor = AutoProcessor.from_pretrained(model_id)
692
+ >>> model = LasrForCTC.from_pretrained(model_id)
693
+
694
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
695
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
696
+
697
+ >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
698
+ >>> predicted_ids = model.generate(**inputs)
699
+ >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
700
+
701
+ >>> print(transcription)
702
+ ```
703
+ """
704
+ kwargs["return_dict"] = True
705
+ outputs: CausalLMOutput = self.forward(
706
+ input_features=input_features,
707
+ attention_mask=attention_mask,
708
+ **kwargs,
709
+ )
710
+
711
+ # greedy decoding
712
+ sequences = outputs.logits.argmax(dim=-1)
713
+
714
+ # mask out padded tokens
715
+ if attention_mask is not None:
716
+ attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequences.shape[1])
717
+ sequences[~attention_mask] = self.config.pad_token_id
718
+
719
+ if return_dict_in_generate:
720
+ return LasrGenerateOutput(
721
+ sequences=sequences,
722
+ logits=outputs.logits,
723
+ attentions=outputs.attentions,
724
+ hidden_states=outputs.hidden_states,
725
+ )
726
+
727
+ return sequences
728
+
729
+
730
+ __all__ = ["LasrForCTC", "LasrEncoder", "LasrPreTrainedModel"]