transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671) hide show
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -35,7 +35,7 @@ from transformers.activations import ACT2FN
35
35
  from ... import initialization as init
36
36
  from ...cache_utils import Cache
37
37
  from ...generation import GenerationMixin
38
- from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
38
+ from ...integrations import lazy_load_kernel, use_kernel_forward_from_hub, use_kernelized_func
39
39
  from ...modeling_attn_mask_utils import AttentionMaskConverter
40
40
  from ...modeling_layers import GradientCheckpointingLayer
41
41
  from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -44,22 +44,9 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
44
44
  from ...processing_utils import Unpack
45
45
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
46
46
  from ...utils.generic import maybe_autocast
47
- from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
48
47
  from .configuration_bamba import BambaConfig
49
48
 
50
49
 
51
- if is_mamba_2_ssm_available():
52
- from mamba_ssm.ops.triton.selective_state_update import selective_state_update
53
- from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
54
- else:
55
- selective_state_update = None
56
-
57
- if is_causal_conv1d_available():
58
- from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
59
- else:
60
- causal_conv1d_update, causal_conv1d_fn = None, None
61
-
62
-
63
50
  logger = logging.get_logger(__name__)
64
51
 
65
52
 
@@ -212,7 +199,7 @@ class BambaRotaryEmbedding(nn.Module):
212
199
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
213
200
 
214
201
  self.register_buffer("inv_freq", inv_freq, persistent=False)
215
- self.original_inv_freq = inv_freq
202
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
216
203
 
217
204
  @staticmethod
218
205
  def compute_default_rope_parameters(
@@ -501,9 +488,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask):
501
488
  return hidden_states
502
489
 
503
490
 
504
- is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
505
-
506
-
507
491
  # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
508
492
  class BambaMixer(nn.Module):
509
493
  """
@@ -575,6 +559,20 @@ class BambaMixer(nn.Module):
575
559
 
576
560
  self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
577
561
 
562
+ global causal_conv1d_update, causal_conv1d_fn
563
+ causal_conv1d = lazy_load_kernel("causal-conv1d")
564
+ causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
565
+ causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
566
+
567
+ global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
568
+ mamba_ssm = lazy_load_kernel("mamba-ssm")
569
+ selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
570
+ mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
571
+ mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
572
+
573
+ global is_fast_path_available
574
+ is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
575
+
578
576
  if not is_fast_path_available:
579
577
  logger.warning_once(
580
578
  "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
@@ -1489,6 +1487,7 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
1489
1487
  cache_position=None,
1490
1488
  position_ids=None,
1491
1489
  use_cache=True,
1490
+ is_first_iteration=False,
1492
1491
  **kwargs,
1493
1492
  ):
1494
1493
  # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
@@ -1521,7 +1520,7 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
1521
1520
  position_ids = position_ids[:, -input_ids.shape[1] :]
1522
1521
 
1523
1522
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1524
- if inputs_embeds is not None and empty_past_kv:
1523
+ if inputs_embeds is not None and is_first_iteration:
1525
1524
  model_inputs = {"inputs_embeds": inputs_embeds}
1526
1525
  else:
1527
1526
  model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
@@ -43,6 +43,7 @@ from transformers.models.mamba2.modeling_mamba2 import (
43
43
  )
44
44
 
45
45
  from ... import initialization as init
46
+ from ...integrations import lazy_load_kernel
46
47
  from ...modeling_attn_mask_utils import AttentionMaskConverter
47
48
  from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
48
49
  from ...modeling_utils import PreTrainedModel
@@ -52,24 +53,9 @@ from ...utils import (
52
53
  can_return_tuple,
53
54
  logging,
54
55
  )
55
- from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
56
56
  from .configuration_bamba import BambaConfig
57
57
 
58
58
 
59
- if is_mamba_2_ssm_available():
60
- from mamba_ssm.ops.triton.selective_state_update import selective_state_update
61
- from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
62
- else:
63
- selective_state_update = None
64
-
65
- if is_causal_conv1d_available():
66
- from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
67
- else:
68
- causal_conv1d_update, causal_conv1d_fn = None, None
69
-
70
- is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
71
-
72
-
73
59
  logger = logging.get_logger(__name__)
74
60
 
75
61
 
@@ -276,6 +262,20 @@ class BambaMixer(nn.Module):
276
262
 
277
263
  self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
278
264
 
265
+ global causal_conv1d_update, causal_conv1d_fn
266
+ causal_conv1d = lazy_load_kernel("causal-conv1d")
267
+ causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
268
+ causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
269
+
270
+ global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
271
+ mamba_ssm = lazy_load_kernel("mamba-ssm")
272
+ selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
273
+ mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
274
+ mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
275
+
276
+ global is_fast_path_available
277
+ is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
278
+
279
279
  if not is_fast_path_available:
280
280
  logger.warning_once(
281
281
  "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
@@ -1151,6 +1151,7 @@ class BambaForCausalLM(LlamaForCausalLM):
1151
1151
  cache_position=None,
1152
1152
  position_ids=None,
1153
1153
  use_cache=True,
1154
+ is_first_iteration=False,
1154
1155
  **kwargs,
1155
1156
  ):
1156
1157
  # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
@@ -1183,7 +1184,7 @@ class BambaForCausalLM(LlamaForCausalLM):
1183
1184
  position_ids = position_ids[:, -input_ids.shape[1] :]
1184
1185
 
1185
1186
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1186
- if inputs_embeds is not None and empty_past_kv:
1187
+ if inputs_embeds is not None and is_first_iteration:
1187
1188
  model_inputs = {"inputs_embeds": inputs_embeds}
1188
1189
  else:
1189
1190
  model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
@@ -23,6 +23,7 @@ import torch
23
23
  from torch import nn
24
24
  from torch.nn import functional as F
25
25
 
26
+ from ... import initialization as init
26
27
  from ...cache_utils import Cache, DynamicCache
27
28
  from ...generation import GenerationMixin
28
29
  from ...generation.logits_process import (
@@ -349,6 +350,14 @@ class BarkPreTrainedModel(PreTrainedModel):
349
350
 
350
351
  return super().device
351
352
 
353
+ def _init_weights(self, module):
354
+ super()._init_weights(module)
355
+ if isinstance(module, BarkSelfAttention):
356
+ if module.is_causal:
357
+ block_size = module.config.block_size
358
+ bias = torch.tril(torch.ones((block_size, block_size), dtype=bool)).view(1, 1, block_size, block_size)
359
+ init.copy_(module.bias, bias)
360
+
352
361
 
353
362
  # GPT2-like autoregressive model
354
363
  class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
@@ -157,7 +157,6 @@ class BartConfig(PreTrainedConfig):
157
157
  decoder_start_token_id=decoder_start_token_id,
158
158
  **kwargs,
159
159
  )
160
- self.tie_encoder_decoder = True
161
160
 
162
161
 
163
162
  __all__ = ["BartConfig"]
@@ -23,6 +23,7 @@ import torch
23
23
  from torch import nn
24
24
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
25
 
26
+ from ... import initialization as init
26
27
  from ...activations import ACT2FN
27
28
  from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
28
29
  from ...generation import GenerationMixin
@@ -476,6 +477,11 @@ class BartPreTrainedModel(PreTrainedModel):
476
477
 
477
478
  _can_compile_fullgraph = True
478
479
 
480
+ def _init_weights(self, module):
481
+ super()._init_weights(module)
482
+ if isinstance(module, BartForConditionalGeneration):
483
+ init.zeros_(module.final_logits_bias)
484
+
479
485
  @property
480
486
  def dummy_inputs(self):
481
487
  pad_token = self.config.pad_token_id
@@ -1463,6 +1469,7 @@ class BartDecoderWrapper(BartPreTrainedModel):
1463
1469
  def __init__(self, config):
1464
1470
  super().__init__(config)
1465
1471
  self.decoder = BartDecoder(config)
1472
+ self.post_init()
1466
1473
 
1467
1474
  def forward(self, *args, **kwargs):
1468
1475
  return self.decoder(*args, **kwargs)
@@ -163,7 +163,6 @@ class BeitImageProcessorFast(BaseImageProcessorFast):
163
163
  processed_images_grouped[shape] = stacked_images
164
164
 
165
165
  processed_images = reorder_images(processed_images_grouped, grouped_images_index)
166
- processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
167
166
 
168
167
  return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
169
168
 
@@ -569,6 +569,9 @@ class BertPreTrainedModel(PreTrainedModel):
569
569
  super()._init_weights(module)
570
570
  if isinstance(module, BertLMPredictionHead):
571
571
  init.zeros_(module.bias)
572
+ elif isinstance(module, BertEmbeddings):
573
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
574
+ init.zeros_(module.token_type_ids)
572
575
 
573
576
 
574
577
  @dataclass
@@ -463,6 +463,8 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
463
463
  super()._init_weights(module)
464
464
  if isinstance(module, BertGenerationOnlyLMHead):
465
465
  init.zeros_(module.bias)
466
+ elif isinstance(module, BertGenerationEmbeddings):
467
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
466
468
 
467
469
 
468
470
  @auto_docstring(
@@ -1521,6 +1521,9 @@ class BigBirdPreTrainedModel(PreTrainedModel):
1521
1521
  super()._init_weights(module)
1522
1522
  if isinstance(module, BigBirdLMPredictionHead):
1523
1523
  init.zeros_(module.bias)
1524
+ elif isinstance(module, BigBirdEmbeddings):
1525
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
1526
+ init.zeros_(module.token_type_ids)
1524
1527
 
1525
1528
 
1526
1529
  @dataclass
@@ -23,6 +23,7 @@ import torch
23
23
  from torch import nn
24
24
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
25
 
26
+ from ... import initialization as init
26
27
  from ...activations import ACT2FN
27
28
  from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
28
29
  from ...generation import GenerationMixin
@@ -1536,6 +1537,11 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
1536
1537
  _skip_keys_device_placement = "past_key_values"
1537
1538
  _can_compile_fullgraph = True
1538
1539
 
1540
+ def _init_weights(self, module):
1541
+ super()._init_weights(module)
1542
+ if isinstance(module, BigBirdPegasusForConditionalGeneration):
1543
+ init.zeros_(module.final_logits_bias)
1544
+
1539
1545
  @property
1540
1546
  def dummy_inputs(self):
1541
1547
  pad_token = self.config.pad_token_id
@@ -2582,6 +2588,7 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
2582
2588
  def __init__(self, config):
2583
2589
  super().__init__(config)
2584
2590
  self.decoder = BigBirdPegasusDecoder(config)
2591
+ self.post_init()
2585
2592
 
2586
2593
  def forward(self, *args, **kwargs):
2587
2594
  return self.decoder(*args, **kwargs)
@@ -84,7 +84,7 @@ class WeightStandardizedConv2d(nn.Conv2d):
84
84
  """Conv2d with Weight Standardization. Used for ViT Hybrid model.
85
85
 
86
86
  Paper: [Micro-Batch Training with Batch-Channel Normalization and Weight
87
- Standardization](https://huggingface.co/papers/1903.10520v2)
87
+ Standardization](https://huggingface.co/papers/1903.10520)
88
88
  """
89
89
 
90
90
  def __init__(
@@ -643,6 +643,10 @@ class BitPreTrainedModel(PreTrainedModel):
643
643
  elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
644
644
  init.constant_(module.weight, 1)
645
645
  init.constant_(module.bias, 0)
646
+ if getattr(module, "running_mean", None) is not None:
647
+ init.zeros_(module.running_mean)
648
+ init.ones_(module.running_var)
649
+ init.zeros_(module.num_batches_tracked)
646
650
 
647
651
 
648
652
  @auto_docstring
@@ -287,7 +287,7 @@ class BitNetRotaryEmbedding(nn.Module):
287
287
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
288
288
 
289
289
  self.register_buffer("inv_freq", inv_freq, persistent=False)
290
- self.original_inv_freq = inv_freq
290
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
291
291
 
292
292
  @staticmethod
293
293
  def compute_default_rope_parameters(
@@ -24,6 +24,7 @@ import torch
24
24
  from torch import nn
25
25
  from torch.nn import CrossEntropyLoss
26
26
 
27
+ from ... import initialization as init
27
28
  from ...activations import ACT2FN
28
29
  from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
29
30
  from ...generation import GenerationMixin
@@ -437,6 +438,11 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
437
438
  _supports_flex_attn = True
438
439
  _can_compile_fullgraph = True
439
440
 
441
+ def _init_weights(self, module):
442
+ super()._init_weights(module)
443
+ if isinstance(module, BlenderbotForConditionalGeneration):
444
+ init.zeros_(module.final_logits_bias)
445
+
440
446
  @property
441
447
  def dummy_inputs(self):
442
448
  pad_token = self.config.pad_token_id
@@ -1156,6 +1162,7 @@ class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
1156
1162
  def __init__(self, config):
1157
1163
  super().__init__(config)
1158
1164
  self.decoder = BlenderbotDecoder(config)
1165
+ self.post_init()
1159
1166
 
1160
1167
  def forward(self, *args, **kwargs):
1161
1168
  return self.decoder(*args, **kwargs)
@@ -160,13 +160,6 @@ class BlenderbotTokenizer(TokenizersBackend):
160
160
 
161
161
  self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
162
162
  self._tokenizer.decoder = decoders.ByteLevel()
163
- self._tokenizer.post_processor = processors.RobertaProcessing(
164
- sep=(str(eos_token), self._vocab.get(str(eos_token), 2)),
165
- cls=(str(bos_token), self._vocab.get(str(bos_token), 0)),
166
- add_prefix_space=add_prefix_space,
167
- trim_offsets=True,
168
- )
169
-
170
163
  super().__init__(
171
164
  bos_token=bos_token,
172
165
  eos_token=eos_token,
@@ -178,6 +171,12 @@ class BlenderbotTokenizer(TokenizersBackend):
178
171
  add_prefix_space=add_prefix_space,
179
172
  **kwargs,
180
173
  )
174
+ self._tokenizer.post_processor = processors.RobertaProcessing(
175
+ sep=(str(eos_token), self.eos_token_id),
176
+ cls=(str(bos_token), self.bos_token_id),
177
+ add_prefix_space=add_prefix_space,
178
+ trim_offsets=True,
179
+ )
181
180
 
182
181
 
183
182
  __all__ = ["BlenderbotTokenizer"]
@@ -22,6 +22,7 @@ import torch
22
22
  from torch import nn
23
23
  from torch.nn import CrossEntropyLoss
24
24
 
25
+ from ... import initialization as init
25
26
  from ...activations import ACT2FN
26
27
  from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
27
28
  from ...generation import GenerationMixin
@@ -430,6 +431,11 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
430
431
  _supports_flex_attn = True
431
432
  _can_compile_fullgraph = True
432
433
 
434
+ def _init_weights(self, module):
435
+ super()._init_weights(module)
436
+ if isinstance(module, BlenderbotSmallForConditionalGeneration):
437
+ init.zeros_(module.final_logits_bias)
438
+
433
439
  @property
434
440
  def dummy_inputs(self):
435
441
  pad_token = self.config.pad_token_id
@@ -1116,6 +1122,7 @@ class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
1116
1122
  def __init__(self, config):
1117
1123
  super().__init__(config)
1118
1124
  self.decoder = BlenderbotSmallDecoder(config)
1125
+ self.post_init()
1119
1126
 
1120
1127
  def forward(self, *args, **kwargs):
1121
1128
  return self.decoder(*args, **kwargs)
@@ -430,6 +430,8 @@ class BlipPreTrainedModel(PreTrainedModel):
430
430
  std = self.config.vision_config.initializer_range
431
431
  init.trunc_normal_(module.position_embedding, mean=0.0, std=std)
432
432
  init.trunc_normal_(module.class_embedding, mean=0.0, std=std)
433
+ elif isinstance(module, BlipTextEmbeddings):
434
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
433
435
 
434
436
 
435
437
  class BlipEncoder(nn.Module):
@@ -21,6 +21,7 @@ import torch
21
21
  from torch import Tensor, device, nn
22
22
  from torch.nn import CrossEntropyLoss
23
23
 
24
+ from ... import initialization as init
24
25
  from ...activations import ACT2FN
25
26
  from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
26
27
  from ...generation import GenerationMixin
@@ -504,6 +505,11 @@ class BlipTextPreTrainedModel(PreTrainedModel):
504
505
  base_model_prefix = "bert"
505
506
  _no_split_modules = []
506
507
 
508
+ def _init_weights(self, module):
509
+ super()._init_weights(module)
510
+ if isinstance(module, BlipTextEmbeddings):
511
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
512
+
507
513
 
508
514
  # Adapted from https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/med.py#L571
509
515
  class BlipTextModel(BlipTextPreTrainedModel):
@@ -740,6 +746,8 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin):
740
746
  self.cls = BlipTextOnlyMLMHead(config)
741
747
  self.label_smoothing = config.label_smoothing
742
748
 
749
+ self.post_init()
750
+
743
751
  def get_input_embeddings(self):
744
752
  return self.bert.get_input_embeddings()
745
753
 
@@ -428,6 +428,8 @@ class Blip2PreTrainedModel(PreTrainedModel):
428
428
  ),
429
429
  ):
430
430
  init.zeros_(module.query_tokens)
431
+ elif isinstance(module, Blip2TextEmbeddings):
432
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
431
433
 
432
434
 
433
435
  # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2
@@ -714,36 +714,21 @@ class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
714
714
  inputs_embeds=None,
715
715
  cache_position=None,
716
716
  use_cache=True,
717
+ is_first_iteration=False,
717
718
  **kwargs,
718
719
  ):
719
720
  # Overwritten because of the fixed-shape attention mask creation
720
721
 
721
- # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
722
- # Exception 1: when passing input_embeds, input_ids may be missing entries
723
- # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
724
- # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
725
- # (we can't check exception 3 while compiling)
726
- # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
727
- # generate the first token for each sequence. Later use the generated Input ids for continuation.
728
- if past_key_values is not None:
729
- if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
730
- inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
731
- elif (
732
- inputs_embeds is not None # Exception 1
733
- or cache_position[-1] >= input_ids.shape[1] # Exception 3
734
- ):
735
- input_ids = input_ids[:, -cache_position.shape[0] :]
736
- elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
737
- input_ids = input_ids[:, cache_position]
738
-
739
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
740
- if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
741
- model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
742
- else:
743
- # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the
744
- # input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in
745
- # the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
746
- model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
722
+ model_inputs = super().prepare_inputs_for_generation(
723
+ input_ids,
724
+ past_key_values=past_key_values,
725
+ attention_mask=attention_mask,
726
+ inputs_embeds=inputs_embeds,
727
+ cache_position=cache_position,
728
+ use_cache=use_cache,
729
+ is_first_iteration=is_first_iteration,
730
+ **kwargs,
731
+ )
747
732
 
748
733
  # This part differs from other models because BLOOM needs a 2D mask to construct alibi tensor
749
734
  # The only difference is the usage of 2D instead of 4D mask, but the shape will be static
@@ -753,24 +738,8 @@ class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
753
738
  diff = target_length - seq_length
754
739
 
755
740
  new_attn_mask = torch.zeros(batch_size, diff, device=attention_mask.device, dtype=attention_mask.dtype)
756
- attention_mask = torch.cat(
757
- [attention_mask, new_attn_mask],
758
- dim=-1,
759
- )
760
-
761
- model_inputs.update(
762
- {
763
- "cache_position": cache_position,
764
- "past_key_values": past_key_values,
765
- "use_cache": use_cache,
766
- "attention_mask": attention_mask,
767
- }
768
- )
769
-
770
- # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
771
- for key, value in kwargs.items():
772
- if key not in model_inputs:
773
- model_inputs[key] = value
741
+ attention_mask = torch.cat([attention_mask, new_attn_mask], dim=-1)
742
+ model_inputs["attention_mask"] = attention_mask
774
743
 
775
744
  return model_inputs
776
745