transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0

transformers/models/tvp/modeling_tvp.py

@@ -16,7 +16,7 @@
 
 import math
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from torch import nn
@@ -462,7 +462,7 @@ class TvpEncoder(nn.Module):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[tuple, BaseModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -722,7 +722,7 @@ class TvpModel(TvpPreTrainedModel):
         return_dict: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         **kwargs,
-    ):
+    ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         Examples:
         ```python
@@ -824,7 +824,7 @@ class TvpForVideoGrounding(TvpPreTrainedModel):
         return_dict: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         **kwargs,
-    ):
+    ) -> Union[tuple, TvpVideoGroundingOutput]:
         r"""
         labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
             The labels contains duration, start time, and end time of the video corresponding to the text.

transformers/models/udop/configuration_udop.py

@@ -149,6 +149,7 @@ class UdopConfig(PreTrainedConfig):
                 "'gated-gelu' or 'relu'"
             )
 
+        kwargs["tie_word_embeddings"] = True
         super().__init__(
             pad_token_id=pad_token_id,
             eos_token_id=eos_token_id,

transformers/models/udop/modeling_udop.py

@@ -1106,7 +1106,7 @@ class UdopStack(UdopPreTrainedModel):
        return_dict=None,
        cache_position=None,
        **kwargs,
-    ):
+    ) -> Union[tuple, BaseModelOutputWithAttentionMask]:
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
@@ -1436,12 +1436,10 @@
         encoder_config = deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_word_embeddings = True
         self.encoder = UdopStack(encoder_config)
 
         decoder_config = deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_word_embeddings = True
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = UdopStack(decoder_config)
 
@@ -1476,7 +1474,7 @@
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
-    ) -> tuple[Tensor, ...]:
+    ) -> Union[tuple, Seq2SeqModelOutput]:
         r"""
         bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
             Bounding boxes of each input sequence tokens. Selected in the range `[0,
@@ -1611,12 +1609,10 @@ class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin):
         encoder_config = deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = UdopStack(encoder_config)
 
         decoder_config = deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = UdopStack(decoder_config)
 
@@ -1655,7 +1651,7 @@ class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin):
         labels: Optional[Tensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
-    ) -> tuple[Tensor, ...]:
+    ) -> Union[tuple, Seq2SeqLMOutput]:
         r"""
         bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
             Bounding boxes of each input sequence tokens. Selected in the range `[0,

transformers/models/umt5/configuration_umt5.py

@@ -94,7 +94,6 @@ class UMT5Config(PreTrainedConfig):
         is_encoder_decoder=True,
         use_cache=True,
         tokenizer_class="T5Tokenizer",
-        tie_word_embeddings=True,
         pad_token_id=0,
         eos_token_id=1,
         decoder_start_token_id=0,
@@ -133,10 +132,11 @@
         if feed_forward_proj == "gated-gelu":
             self.dense_act_fn = "gelu_new"
 
+        # Force because official weights have False serialized, but we have to tie always
+        kwargs["tie_word_embeddings"] = True
         super().__init__(
             is_encoder_decoder=is_encoder_decoder,
             tokenizer_class=tokenizer_class,
-            tie_word_embeddings=tie_word_embeddings,
             pad_token_id=pad_token_id,
             eos_token_id=eos_token_id,
             decoder_start_token_id=decoder_start_token_id,

transformers/models/umt5/modeling_umt5.py

@@ -929,12 +929,10 @@ class UMT5Model(UMT5PreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = UMT5Stack(encoder_config)
 
         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = UMT5Stack(decoder_config)
 
@@ -1108,12 +1106,10 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel, GenerationMixin):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = UMT5Stack(encoder_config)
 
         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = UMT5Stack(decoder_config)
 
@@ -1614,12 +1610,10 @@ class UMT5ForQuestionAnswering(UMT5PreTrainedModel):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.tie_encoder_decoder = False
         self.encoder = UMT5Stack(encoder_config)
 
         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = UMT5Stack(decoder_config)
 

transformers/models/vaultgemma/modeling_vaultgemma.py

@@ -297,7 +297,7 @@ class VaultGemmaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(

transformers/models/video_llama_3/image_processing_video_llama_3.py

@@ -154,8 +154,9 @@ class VideoLlama3ImageProcessor(BaseImageProcessor):
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
-            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+        if size is not None:
+            if "shortest_edge" not in size or "longest_edge" not in size:
+                raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
         else:
             size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
         # backward compatibility: override size with min_pixels and max_pixels if they are provided

transformers/models/video_llama_3/modeling_video_llama_3.py

@@ -25,6 +25,7 @@ import torch
 import torch.nn as nn
 from torch.nn import LayerNorm
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
@@ -43,6 +44,8 @@ class VideoLlama3VisionRotaryEmbedding(nn.Module):
 
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
+        self.dim = dim
+        self.theta = theta
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
@@ -380,6 +383,12 @@ class VideoLlama3PreTrainedModel(PreTrainedModel):
     _can_compile_fullgraph = True
     _supports_attention_backend = True
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, VideoLlama3VisionRotaryEmbedding):
+            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
+
 
 class VideoLlama3VisionModel(VideoLlama3PreTrainedModel):
     config: VideoLlama3VisionConfig
@@ -855,6 +864,7 @@ class VideoLlama3ForConditionalGeneration(VideoLlama3PreTrainedModel, Generation
         video_grid_thw: Optional[torch.LongTensor] = None,
         video_merge_sizes: Optional[torch.LongTensor] = None,
         video_compression_mask: Optional[torch.BoolTensor] = None,
+        is_first_iteration: Optional[bool] = False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -874,10 +884,11 @@
             video_merge_sizes=video_merge_sizes,
             video_compression_mask=video_compression_mask,
             use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        if model_inputs["cache_position"][0] != 0:
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None
             model_inputs["pixel_values_videos"] = None
 

transformers/models/video_llama_3/modular_video_llama_3.py

@@ -21,6 +21,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import LayerNorm
 
+from ... import initialization as init
 from ...cache_utils import Cache
 from ...configuration_utils import PreTrainedConfig
 from ...feature_extraction_utils import BatchFeature
@@ -433,6 +434,12 @@ class VideoLlama3PreTrainedModel(Qwen2VLPreTrainedModel):
     config: VideoLlama3Config
     _no_split_modules = ["VideoLlama3VisionEncoderLayer"]
 
+    def _init_weights(self, module):
+        PreTrainedModel._init_weights(self, module)
+        if isinstance(module, VideoLlama3VisionRotaryEmbedding):
+            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
+
 
 class VideoLlama3VisionModel(VideoLlama3PreTrainedModel):
     config: VideoLlama3VisionConfig
@@ -842,6 +849,7 @@ class VideoLlama3ForConditionalGeneration(Qwen2VLForConditionalGeneration):
         video_grid_thw: Optional[torch.LongTensor] = None,
         video_merge_sizes: Optional[torch.LongTensor] = None,
         video_compression_mask: Optional[torch.BoolTensor] = None,
+        is_first_iteration: Optional[bool] = False,
         **kwargs,
    ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -861,10 +869,11 @@
             video_merge_sizes=video_merge_sizes,
             video_compression_mask=video_compression_mask,
             use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        if model_inputs["cache_position"][0] != 0:
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None
             model_inputs["pixel_values_videos"] = None
 
@@ -599,6 +599,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
599
599
  attention_mask=None,
600
600
  cache_position=None,
601
601
  logits_to_keep=None,
602
+ is_first_iteration=False,
602
603
  **kwargs,
603
604
  ):
604
605
  # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -610,12 +611,15 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
610
611
  attention_mask=attention_mask,
611
612
  cache_position=cache_position,
612
613
  logits_to_keep=logits_to_keep,
614
+ is_first_iteration=is_first_iteration,
613
615
  **kwargs,
614
616
  )
615
617
 
616
- if cache_position[0] == 0:
617
- # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
618
- # Otherwise we need pixel values to be passed to model
618
+ if is_first_iteration or not kwargs.get("use_cache", True):
619
+ # Pixel values are used only in the first iteration if available
620
+ # In subsquent iterations, they are already merged with text and cached
621
+ # NOTE: first iteration doesn't have to be prefill, it can be the first
622
+ # iteration with a question and cached system prompt (continue generate from cache)
619
623
  model_inputs["pixel_values_images"] = pixel_values_images
620
624
  model_inputs["pixel_values_videos"] = pixel_values_videos
@@ -115,7 +115,7 @@ class ViltConfig(PreTrainedConfig):
  num_channels=3,
  qkv_bias=True,
  max_image_length=-1,
- tie_word_embeddings=False,
+ tie_word_embeddings=True,
  num_images=-1,
  **kwargs,
  ):
@@ -142,7 +142,7 @@ class ViltConfig(PreTrainedConfig):
  self.qkv_bias = qkv_bias
  self.max_image_length = max_image_length
  self.num_images = num_images
- self.tie_encoder_decoder = True
+ self.tie_word_embeddings = True # force it

  __all__ = ["ViltConfig"]
@@ -23,6 +23,7 @@ import torch
  from torch import nn
  from torch.nn import CrossEntropyLoss

+ from ... import initialization as init
  from ...activations import ACT2FN
  from ...modeling_layers import GradientCheckpointingLayer
  from ...modeling_outputs import (
@@ -516,6 +517,12 @@ class ViltPreTrainedModel(PreTrainedModel):
  supports_gradient_checkpointing = True
  _no_split_modules = ["ViltEmbeddings", "ViltSelfAttention"]

+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, TextEmbeddings):
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+ init.zeros_(module.token_type_ids)
+

  @auto_docstring
  class ViltModel(ViltPreTrainedModel):
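The Vilt `_init_weights` addition rebuilds the text embeddings' bookkeeping buffers rather than loading them: consecutive position ids and an all-zero segment-id tensor. A small sketch of the two values being written (`seq_len` is an illustrative stand-in for `max_position_embeddings`):

import torch

seq_len = 40  # illustrative value for max_position_embeddings
position_ids = torch.arange(seq_len).expand((1, -1))          # shape (1, seq_len): 0, 1, ..., seq_len - 1
token_type_ids = torch.zeros((1, seq_len), dtype=torch.long)  # BERT-style all-zero segment ids
print(position_ids.shape, token_type_ids.unique())            # torch.Size([1, 40]) tensor([0])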
@@ -415,6 +415,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
  attention_mask=None,
  cache_position=None,
  logits_to_keep=None,
+ is_first_iteration=False,
  **kwargs,
  ):
  # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -426,12 +427,15 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
  attention_mask=attention_mask,
  cache_position=cache_position,
  logits_to_keep=logits_to_keep,
+ is_first_iteration=is_first_iteration,
  **kwargs,
  )

- if cache_position[0] == 0:
- # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
- # Otherwise we need pixel values to be passed to model
+ if is_first_iteration or not kwargs.get("use_cache", True):
+ # Pixel values are used only in the first iteration if available
+ # In subsequent iterations, they are already merged with text and cached
+ # NOTE: first iteration doesn't have to be prefill, it can be the first
+ # iteration with a question and cached system prompt (continuing generation from cache)
  model_inputs["pixel_values"] = pixel_values

  return model_inputs
@@ -473,6 +473,8 @@ class VisualBertPreTrainedModel(PreTrainedModel):
  init.ones_(module.weight)
  elif isinstance(module, VisualBertLMPredictionHead):
  init.zeros_(module.bias)
+ elif isinstance(module, VisualBertEmbeddings):
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))

  @dataclass
@@ -36,7 +36,7 @@ class VitMatteConfig(PreTrainedConfig):
  documentation from [`PreTrainedConfig`] for more information.

  Args:
- backbone_config (`PreTrainedConfig` or `dict`, *optional*, defaults to `VitDetConfig()`):
+ backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `VitDetConfig()`):
  The configuration of the backbone model.
  backbone (`str`, *optional*):
  Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -152,7 +152,6 @@ class VitMatteImageProcessorFast(BaseImageProcessorFast):
  processed_images_grouped[shape] = stacked_images

  processed_images = reorder_images(processed_images_grouped, grouped_images_index)
- processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

  return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

@@ -65,6 +65,10 @@ class VitMattePreTrainedModel(PreTrainedModel):
  init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  if module.bias is not None:
  init.zeros_(module.bias)
+ if getattr(module, "running_mean", None) is not None:
+ init.zeros_(module.running_mean)
+ init.ones_(module.running_var)
+ init.zeros_(module.num_batches_tracked)

  class VitMatteBasicConv3x3(nn.Module):
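Several `_init_weights` hooks (here and in the Wav2Vec2Conformer hunk further down) now also reset BatchNorm running statistics when present; these are buffers rather than parameters, so a parameter-only reinitialization would leave them untouched. The equivalent reset on a plain PyTorch module, written with in-place tensor ops rather than the library's `init` helpers:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(16)
with torch.no_grad():
    bn.running_mean.zero_()          # running mean -> 0
    bn.running_var.fill_(1.0)        # running variance -> 1
    bn.num_batches_tracked.zero_()   # batch counter -> 0
print(bn.running_var[:3])  # tensor([1., 1., 1.])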
@@ -36,7 +36,7 @@ class VitPoseConfig(PreTrainedConfig):
  documentation from [`PreTrainedConfig`] for more information.

  Args:
- backbone_config (`PreTrainedConfig` or `dict`, *optional*, defaults to `VitPoseBackboneConfig()`):
+ backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `VitPoseBackboneConfig()`):
  The configuration of the backbone model. Currently, only `backbone_config` with `vitpose_backbone` as `model_type` is supported.
  backbone (`str`, *optional*):
  Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -156,7 +156,6 @@ class VitPoseImageProcessorFast(BaseImageProcessorFast):
  processed_images = reorder_images(processed_images_grouped, grouped_images_index)

  # Stack into batch tensor
- processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

  return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
@@ -505,11 +505,11 @@ class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin):
  # Overwritten -- we should not pass input_features when we are in cached decoding stage

  input_features = kwargs.pop("input_features", None)
- cache_position = kwargs.get("cache_position")
+ is_first_iteration = kwargs.get("is_first_iteration", False)

  model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)

- if cache_position is not None and cache_position[0] == 0:
+ if is_first_iteration or not kwargs.get("use_cache", True):
  # input_features should only be passed when we are not in cached decoding stage
  model_inputs["input_features"] = input_features

@@ -267,11 +267,11 @@ class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin):
  # Overwritten -- we should not pass input_features when we are in cached decoding stage

  input_features = kwargs.pop("input_features", None)
- cache_position = kwargs.get("cache_position")
+ is_first_iteration = kwargs.get("is_first_iteration", False)

  model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)

- if cache_position is not None and cache_position[0] == 0:
+ if is_first_iteration or not kwargs.get("use_cache", True):
  # input_features should only be passed when we are not in cached decoding stage
  model_inputs["input_features"] = input_features
@@ -74,18 +74,17 @@ class Wav2Vec2BertRelPositionalEmbedding(nn.Module):
  super().__init__()
  self.max_len = config.max_source_positions
  self.d_model = config.hidden_size
- self.pe = None
- self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+ self.register_buffer("pe", self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)), persistent=False)

- def extend_pe(self, x):
+ def extend_pe(self, x, pe=None):
  # Reset the positional encodings
- if self.pe is not None:
+ if pe is not None:
  # self.pe contains both positive and negative parts
  # the length of self.pe is 2 * input_len - 1
- if self.pe.size(1) >= x.size(1) * 2 - 1:
- if self.pe.dtype != x.dtype or self.pe.device != x.device:
- self.pe = self.pe.to(dtype=x.dtype, device=x.device)
- return
+ if pe.size(1) >= x.size(1) * 2 - 1:
+ if pe.dtype != x.dtype or pe.device != x.device:
+ pe = pe.to(dtype=x.dtype, device=x.device)
+ return pe
  # Suppose `i` is the position of query vector and `j` is the
  # position of key vector. We use positive relative positions when keys
  # are to the left (i>j) and negative relative positions otherwise (i<j).
@@ -106,10 +105,10 @@ class Wav2Vec2BertRelPositionalEmbedding(nn.Module):
  pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
  pe_negative = pe_negative[1:].unsqueeze(0)
  pe = torch.cat([pe_positive, pe_negative], dim=1)
- self.pe = pe.to(device=x.device, dtype=x.dtype)
+ return pe.to(device=x.device, dtype=x.dtype)

  def forward(self, hidden_states: torch.Tensor):
- self.extend_pe(hidden_states)
+ self.pe = self.extend_pe(hidden_states, self.pe)
  start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
  end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
  relative_position_embeddings = self.pe[:, start_idx:end_idx]
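The structural change here is that `pe` becomes a registered non-persistent buffer and `extend_pe` turns into a function that returns the (possibly regrown) table, so the same call can also be reused from `_init_weights`. A stripped-down sketch of that refresh pattern on a toy module (the sinusoidal construction is elided; `make_table` is a placeholder, not the library's code):

import torch
import torch.nn as nn


def make_table(length: int) -> torch.Tensor:
    # Placeholder for the real sinusoidal construction: 2 * length - 1 relative offsets.
    return torch.zeros(1, 2 * length - 1, 4)


class RelPosStub(nn.Module):
    def __init__(self, max_len: int = 8):
        super().__init__()
        # Non-persistent: rebuilt at init time instead of being stored in checkpoints.
        self.register_buffer("pe", make_table(max_len), persistent=False)

    def extend_pe(self, x: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
        # Keep the existing table if it is long enough, otherwise regrow it.
        if pe.size(1) >= x.size(1) * 2 - 1:
            return pe.to(dtype=x.dtype, device=x.device)
        return make_table(x.size(1)).to(dtype=x.dtype, device=x.device)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        self.pe = self.extend_pe(hidden_states, self.pe)
        return self.pe


module = RelPosStub()
print(module(torch.zeros(1, 16, 4)).shape)  # torch.Size([1, 31, 4]) after regrowing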
@@ -749,6 +748,13 @@ class Wav2Vec2BertPreTrainedModel(PreTrainedModel):
  init.constant_(module.layer_weights, 1.0 / (self.config.num_hidden_layers + 1))
  elif isinstance(module, AMSoftmaxLoss): # noqa: F821
  init.normal_(module.weight)
+ elif isinstance(module, Wav2Vec2BertRotaryPositionalEmbedding):
+ dim = self.config.hidden_size // self.config.num_attention_heads
+ base = self.config.rotary_embedding_base
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+ init.copy_(module.inv_freq, inv_freq)
+ elif isinstance(module, Wav2Vec2BertRelPositionalEmbedding):
+ init.copy_(module.pe, module.extend_pe(torch.tensor(0.0).expand(1, module.max_len)))

  # Ignore copy
  def _get_feat_extract_output_lengths(
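The two new branches copy freshly computed tensors into existing buffers (`inv_freq`, `pe`) through the new `initialization` helpers. The exact semantics of `init.copy_` are not shown in this diff; the sketch below is only an assumption about what such a helper would do, namely an in-place, gradient-free copy that respects the destination buffer's dtype and device:

import torch

def copy_(dst: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    # Assumed behaviour, not the library implementation: fill the existing buffer in place.
    with torch.no_grad():
        return dst.copy_(src.to(dtype=dst.dtype, device=dst.device))

buf = torch.empty(4)
copy_(buf, 1.0 / (10000.0 ** (torch.arange(0, 8, 2, dtype=torch.int64).float() / 8)))
print(buf)  # tensor([1.0000, 0.1000, 0.0100, 0.0010])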
@@ -621,6 +621,13 @@ class Wav2Vec2BertPreTrainedModel(PreTrainedModel):
  init.constant_(module.layer_weights, 1.0 / (self.config.num_hidden_layers + 1))
  elif isinstance(module, AMSoftmaxLoss): # noqa: F821
  init.normal_(module.weight)
+ elif isinstance(module, Wav2Vec2BertRotaryPositionalEmbedding):
+ dim = self.config.hidden_size // self.config.num_attention_heads
+ base = self.config.rotary_embedding_base
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+ init.copy_(module.inv_freq, inv_freq)
+ elif isinstance(module, Wav2Vec2BertRelPositionalEmbedding):
+ init.copy_(module.pe, module.extend_pe(torch.tensor(0.0).expand(1, module.max_len)))

  # Ignore copy
  def _get_feat_extract_output_lengths(
@@ -164,18 +164,17 @@ class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
  super().__init__()
  self.max_len = config.max_source_positions
  self.d_model = config.hidden_size
- self.pe = None
- self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+ self.register_buffer("pe", self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)), persistent=False)

- def extend_pe(self, x):
+ def extend_pe(self, x, pe=None):
  # Reset the positional encodings
- if self.pe is not None:
+ if pe is not None:
  # self.pe contains both positive and negative parts
  # the length of self.pe is 2 * input_len - 1
- if self.pe.size(1) >= x.size(1) * 2 - 1:
- if self.pe.dtype != x.dtype or self.pe.device != x.device:
- self.pe = self.pe.to(dtype=x.dtype, device=x.device)
- return
+ if pe.size(1) >= x.size(1) * 2 - 1:
+ if pe.dtype != x.dtype or pe.device != x.device:
+ pe = pe.to(dtype=x.dtype, device=x.device)
+ return pe
  # Suppose `i` is the position of query vector and `j` is the
  # position of key vector. We use positive relative positions when keys
  # are to the left (i>j) and negative relative positions otherwise (i<j).
@@ -196,10 +195,10 @@ class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
  pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
  pe_negative = pe_negative[1:].unsqueeze(0)
  pe = torch.cat([pe_positive, pe_negative], dim=1)
- self.pe = pe.to(device=x.device, dtype=x.dtype)
+ return pe.to(device=x.device, dtype=x.dtype)

  def forward(self, hidden_states: torch.Tensor):
- self.extend_pe(hidden_states)
+ self.pe = self.extend_pe(hidden_states, self.pe)
  start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
  end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
  relative_position_embeddings = self.pe[:, start_idx:end_idx]
@@ -885,15 +884,26 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):

  if module.bias is not None:
  init.zeros_(module.bias)
- elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)):
  init.zeros_(module.bias)
  init.ones_(module.weight)
+ if getattr(module, "running_mean", None) is not None:
+ init.zeros_(module.running_mean)
+ init.ones_(module.running_var)
+ init.zeros_(module.num_batches_tracked)
  elif isinstance(module, nn.Conv1d):
  init.kaiming_normal_(module.weight)

  if module.bias is not None:
  k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  init.uniform_(module.bias, a=-k, b=k)
+ elif isinstance(module, Wav2Vec2ConformerRotaryPositionalEmbedding):
+ dim = self.config.hidden_size // self.config.num_attention_heads
+ base = self.config.rotary_embedding_base
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+ init.copy_(module.inv_freq, inv_freq)
+ elif isinstance(module, Wav2Vec2ConformerRelPositionalEmbedding):
+ init.copy_(module.pe, module.extend_pe(torch.tensor(0.0).expand(1, module.max_len)))

  def _get_feat_extract_output_lengths(
  self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None