transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/maskformer/modeling_maskformer_swin.py

@@ -19,7 +19,7 @@ states before downsampling, which is different from the default Swin Transformer
 import collections.abc
 import math
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from torch import Tensor, nn
@@ -331,18 +331,7 @@ class MaskFormerSwinSelfAttention(nn.Module):
             torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
         )
 
-        # get pair-wise relative position index for each token inside the window
-        coords_h = torch.arange(self.window_size[0])
-        coords_w = torch.arange(self.window_size[1])
-        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
-        coords_flatten = torch.flatten(coords, 1)
-        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
-        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
-        relative_coords[:, :, 0] += self.window_size[0] - 1
-        relative_coords[:, :, 1] += self.window_size[1] - 1
-        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
-        relative_position_index = relative_coords.sum(-1)
-        self.register_buffer("relative_position_index", relative_position_index)
+        self.register_buffer("relative_position_index", self.create_relative_position_index())
 
         self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
         self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
@@ -401,6 +390,20 @@ class MaskFormerSwinSelfAttention(nn.Module):
 
         return outputs
 
+    def create_relative_position_index(self):
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
+        coords_flatten = torch.flatten(coords, 1)
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+        relative_coords[:, :, 0] += self.window_size[0] - 1
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)
+        return relative_position_index
+
 
 # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->MaskFormerSwin
 class MaskFormerSwinSelfOutput(nn.Module):
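Note: the refactor above extracts the index computation into `create_relative_position_index()` so that `_init_weights` (a later hunk in this file) can re-materialize the buffer. A minimal standalone sketch of what the helper computes, calling `torch.meshgrid` directly in place of the module's `meshgrid` wrapper:

```python
import torch

def create_relative_position_index(window_size):
    # Pairwise relative position index for every token pair in an (h, w) window.
    coords_h = torch.arange(window_size[0])
    coords_w = torch.arange(window_size[1])
    coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
    coords_flatten = torch.flatten(coords, 1)  # (2, h*w)
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # (h*w, h*w, 2)
    # Shift both axes to be non-negative, then fold them into a single flat index.
    relative_coords[:, :, 0] += window_size[0] - 1
    relative_coords[:, :, 1] += window_size[1] - 1
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1
    return relative_coords.sum(-1)

# A 2x2 window has 4 tokens; each pair maps to one of (2*2-1)*(2*2-1) = 9
# entries in the relative_position_bias_table.
print(create_relative_position_index((2, 2)))  # shape (4, 4), values in [0, 8]
```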
@@ -656,7 +659,7 @@ class MaskFormerSwinEncoder(nn.Module):
         output_attentions=False,
         output_hidden_states=False,
         return_dict=True,
-    ):
+    ) -> Union[tuple, MaskFormerSwinBaseModelOutput]:
         all_hidden_states = () if output_hidden_states else None
         all_input_dimensions = ()
         all_self_attentions = () if output_attentions else None
@@ -711,6 +714,7 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel):
             init.zeros_(module.position_embeddings)
         elif isinstance(module, MaskFormerSwinSelfAttention):
             init.zeros_(module.relative_position_bias_table)
+            init.copy_(module.relative_position_index, module.create_relative_position_index())
 
 
 class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
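Because `relative_position_index` is a registered buffer, models loaded with fast (meta-device) initialization need `_init_weights` to restore its deterministic value, hence the `init.copy_` call above. The exact `transformers.initialization` helpers are not shown in this diff; a plausible minimal sketch of what a `copy_`-style helper does:

```python
import torch

def copy_(tensor: torch.Tensor, value: torch.Tensor) -> None:
    # Hypothetical stand-in for transformers.initialization.copy_: overwrite a
    # buffer/parameter in place, without autograd tracking, keeping dtype/device.
    with torch.no_grad():
        tensor.copy_(value.to(dtype=tensor.dtype, device=tensor.device))

buf = torch.empty(3, dtype=torch.long)
copy_(buf, torch.arange(3))
print(buf)  # tensor([0, 1, 2])
```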
@@ -739,7 +743,7 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
         interpolate_pos_encoding=False,
         return_dict=None,
         **kwargs,
-    ):
+    ) -> Union[tuple, MaskFormerSwinModelOutputWithPooling]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
transformers/models/mbart/configuration_mbart.py

@@ -147,6 +147,7 @@ class MBartConfig(PreTrainedConfig):
         self.use_cache = use_cache
         self.num_hidden_layers = encoder_layers
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
transformers/models/mbart/modeling_mbart.py

@@ -22,6 +22,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -478,6 +479,11 @@ class MBartPreTrainedModel(PreTrainedModel):
     _supports_flex_attn = True
     _can_compile_fullgraph = True
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, MBartForConditionalGeneration):
+            init.zeros_(module.final_logits_bias)
+
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -1442,6 +1448,7 @@ class MBartDecoderWrapper(MBartPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.decoder = MBartDecoder(config)
+        self.post_init()
 
     def forward(self, *args, **kwargs):
         return self.decoder(*args, **kwargs)
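Both MBart hunks follow the convention used throughout this release: subclasses override `_init_weights`, delegate to `super()`, then fill model-specific tensors, and every `__init__` ends with `post_init()` so that initialization actually runs. An illustrative sketch of the pattern (toy class names, not the real MBart code):

```python
import torch
from torch import nn

class TinyPreTrainedModel(nn.Module):
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)

    def post_init(self):
        # Walks self and all submodules; without this call (the fix in the
        # MBartDecoderWrapper hunk above), _init_weights never runs.
        self.apply(self._init_weights)

class TinyForConditionalGeneration(TinyPreTrainedModel):
    def __init__(self):
        super().__init__()
        self.lm_head = nn.Linear(8, 8)
        self.register_buffer("final_logits_bias", torch.empty(1, 8))
        self.post_init()

    def _init_weights(self, module):
        super()._init_weights(module)
        if isinstance(module, TinyForConditionalGeneration):
            nn.init.zeros_(module.final_logits_bias)

print(TinyForConditionalGeneration().final_logits_bias.sum())  # tensor(0.)
```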
transformers/models/megatron_bert/modeling_megatron_bert.py

@@ -528,6 +528,8 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, MegatronBertLMPredictionHead):
             init.zeros_(module.bias)
+        elif isinstance(module, MegatronBertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
 
 
 @dataclass
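The same `position_ids` expression recurs in the MetaClip2 hunks below; it rebuilds the default buffer, a row vector of sequential positions with a leading batch axis:

```python
import torch

max_position_embeddings = 6
position_ids = torch.arange(max_position_embeddings).expand((1, -1))
print(position_ids)        # tensor([[0, 1, 2, 3, 4, 5]])
print(position_ids.shape)  # torch.Size([1, 6])
```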
transformers/models/metaclip_2/modeling_metaclip_2.py

@@ -306,11 +306,13 @@ class MetaClip2PreTrainedModel(PreTrainedModel):
         if isinstance(module, MetaClip2TextEmbeddings):
             init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
             init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MetaClip2VisionEmbeddings):
             factor = self.config.initializer_factor
             init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
             init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MetaClip2Attention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
transformers/models/metaclip_2/modular_metaclip_2.py

@@ -225,11 +225,13 @@ class MetaClip2PreTrainedModel(CLIPPreTrainedModel):
         if isinstance(module, MetaClip2TextEmbeddings):
             init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
             init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MetaClip2VisionEmbeddings):
             factor = self.config.initializer_factor
             init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
             init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MetaClip2Attention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
transformers/models/mimi/modeling_mimi.py

@@ -521,7 +521,7 @@ class MimiRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -814,8 +814,8 @@ class MimiFlashAttention2(MimiAttention):
                 else torch.get_autocast_gpu_dtype()
             )
         # Handle the case where the model is quantized
-        elif hasattr(self.config, "_pre_quantization_dtype"):
-            target_dtype = self.config._pre_quantization_dtype
+        elif hasattr(self.config, "quantization_config"):
+            target_dtype = self.config.dtype
         else:
             target_dtype = self.q_proj.weight.dtype
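The quantization branch now keys off the public `quantization_config` attribute instead of the removed private `_pre_quantization_dtype` marker, and reads the load dtype from `config.dtype`. A sketch of the resulting dtype cascade (the autocast condition of the full method is elided here):

```python
import torch

def select_flash_attn_dtype(config, weight: torch.Tensor) -> torch.dtype:
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    # Quantized weights may be stored in integer dtypes that flash attention
    # cannot consume, so fall back to the dtype the model was loaded with.
    if hasattr(config, "quantization_config"):
        return config.dtype
    return weight.dtype
```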
@@ -1380,7 +1380,7 @@ class MimiPreTrainedModel(PreTrainedModel):
     main_input_name = "input_values"
     input_modalities = "audio"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["MimiDecoderLayer"]
+    _no_split_modules = ["MimiResidualVectorQuantizer", "MimiTransformerLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn = True
     _supports_sdpa = True
@@ -1404,6 +1404,27 @@ class MimiPreTrainedModel(PreTrainedModel):
             init.uniform_(module.bias, a=-k, b=k)
         elif isinstance(module, MimiLayerScale):
             init.constant_(module.scale, self.config.layer_scale_initial_scale)
+        elif isinstance(module, MimiConv1d):
+            kernel_size = module.conv.kernel_size[0]
+            stride = module.conv.stride[0]
+            dilation = module.conv.dilation[0]
+            kernel_size = (kernel_size - 1) * dilation + 1
+            init.constant_(module.stride, stride)
+            init.constant_(module.kernel_size, kernel_size)
+            init.constant_(module.padding_total, kernel_size - stride)
+        elif isinstance(module, MimiEuclideanCodebook):
+            init.ones_(module.initialized)
+            init.ones_(module.cluster_usage)
+            init.zeros_(module.embed_sum)
+        elif isinstance(module, MimiRotaryEmbedding):
+            rope_fn = (
+                ROPE_INIT_FUNCTIONS[module.rope_type]
+                if module.rope_type != "default"
+                else module.compute_default_rope_parameters
+            )
+            buffer_value, _ = rope_fn(module.config)
+            init.copy_(module.inv_freq, buffer_value)
+            init.copy_(module.original_inv_freq, buffer_value)
 
 
 @auto_docstring(
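The `MimiConv1d` branch re-derives the padding bookkeeping buffers from the wrapped convolution: a dilated kernel of size k with dilation d spans (k - 1) * d + 1 input steps, and `padding_total` is that effective span minus the stride. For example:

```python
# kernel_size=7, dilation=2 -> effective span (7 - 1) * 2 + 1 = 13 inputs;
# with stride=2, total padding is 13 - 2 = 11.
kernel_size, stride, dilation = 7, 2, 2
effective_kernel_size = (kernel_size - 1) * dilation + 1
padding_total = effective_kernel_size - stride
print(effective_kernel_size, padding_total)  # 13 11
```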
transformers/models/minimax/modeling_minimax.py

@@ -31,7 +31,12 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -271,7 +276,7 @@ class MiniMaxRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -473,6 +478,7 @@ class MiniMaxTopKRouter(nn.Module):
         return router_logits, router_scores, router_indices
 
 
+@use_experts_implementation
 class MiniMaxExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
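`use_experts_implementation` comes from the new `transformers/integrations/moe.py` (+240 lines in the file list above); its actual API is not shown in this diff. One plausible reading of the pattern, sketched with hypothetical names: a registry of expert-dispatch forward functions that the decorator binds onto the experts module.

```python
import torch
from torch import nn

_EXPERTS_FORWARDS = {}  # hypothetical registry: name -> forward implementation

def register_experts_forward(name):
    def register(fn):
        _EXPERTS_FORWARDS[name] = fn
        return fn
    return register

def use_experts_implementation(cls):
    # Hypothetical decorator: bind the selected dispatch strategy (eager loop,
    # grouped-GEMM kernel, ...) as the class's forward.
    cls.forward = _EXPERTS_FORWARDS["eager"]
    return cls

@register_experts_forward("eager")
def eager_experts_forward(self, hidden_states, top_k_index, top_k_weights):
    # Route each token to its selected experts and accumulate weighted outputs.
    out = torch.zeros_like(hidden_states)
    for expert_idx in range(self.num_experts):
        token_idx, k_idx = (top_k_index == expert_idx).nonzero(as_tuple=True)
        if token_idx.numel():
            expert_out = self.experts[expert_idx](hidden_states[token_idx])
            out.index_add_(0, token_idx, expert_out * top_k_weights[token_idx, k_idx, None])
    return out

@use_experts_implementation
class TinyExperts(nn.Module):
    def __init__(self, num_experts=4, dim=8):
        super().__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_experts))

tokens = torch.randn(5, 8)
top_k_index = torch.randint(0, 4, (5, 2))
top_k_weights = torch.softmax(torch.randn(5, 2), dim=-1)
print(TinyExperts()(tokens, top_k_index, top_k_weights).shape)  # torch.Size([5, 8])
```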
@@ -596,7 +602,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph = False
+    _can_compile_fullgraph = False  # uses a non-compilable custom cache class MiniMaxCache
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(MiniMaxTopKRouter, layer_name="mlp.gate", index=0),
@@ -613,6 +619,13 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
             init.normal_(module.down_proj, mean=0.0, std=std)
         elif isinstance(module, MiniMaxTopKRouter):
             init.normal_(module.weight, mean=0.0, std=std)
+        if isinstance(module, MiniMaxLightningAttention):
+            slope_rate = module.get_slope_rate()
+            query_decay, key_decay, diagonal_decay = module.decay_factors(slope_rate)
+            init.copy_(module.slope_rate, slope_rate)
+            init.copy_(module.query_decay, query_decay)
+            init.copy_(module.key_decay, key_decay)
+            init.copy_(module.diagonal_decay, diagonal_decay)
 
 
 @auto_docstring
transformers/models/minimax/modular_minimax.py

@@ -21,6 +21,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...configuration_utils import PreTrainedConfig, layer_type_validation
@@ -520,13 +521,23 @@ class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer):
 
 
 class MiniMaxPreTrainedModel(MixtralPreTrainedModel):
-    _can_compile_fullgraph = False
+    _can_compile_fullgraph = False  # uses a non-compilable custom cache class MiniMaxCache
     _can_record_outputs = {
         "router_logits": OutputRecorder(MiniMaxTopKRouter, layer_name="mlp.gate", index=0),
         "hidden_states": MiniMaxDecoderLayer,
         "attentions": [MiniMaxAttention, MiniMaxLightningAttention],
     }
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, MiniMaxLightningAttention):
+            slope_rate = module.get_slope_rate()
+            query_decay, key_decay, diagonal_decay = module.decay_factors(slope_rate)
+            init.copy_(module.slope_rate, slope_rate)
+            init.copy_(module.query_decay, query_decay)
+            init.copy_(module.key_decay, key_decay)
+            init.copy_(module.diagonal_decay, diagonal_decay)
+
 
 class MiniMaxModel(MixtralModel):
     @check_model_inputs
@@ -289,7 +289,7 @@ class MinistralRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -295,7 +295,7 @@ class Ministral3RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -285,7 +285,7 @@ class MistralRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -252,7 +252,9 @@ class Mistral3Model(Mistral3PreTrainedModel):
 
         image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
         downsample_ratio = self.vision_tower.patch_size * self.config.spatial_merge_size
-        split_sizes = [(height // downsample_ratio) * (width // downsample_ratio) for height, width in image_sizes]
+        split_sizes = (
+            (torch.as_tensor(image_sizes, device=image_features.device) // downsample_ratio).prod(dim=-1).tolist()
+        )
         image_features = torch.split(image_features.squeeze(0), split_sizes)
         return image_features
 
@@ -489,6 +491,7 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -500,12 +503,15 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        if cache_position[0] == 0:
-            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
-            # Otherwise we need pixel values to be passed to model
+        if is_first_iteration or not kwargs.get("use_cache", True):
+            # Pixel values are used only in the first iteration, if available.
+            # In subsequent iterations they are already merged with text and cached.
+            # NOTE: the first iteration doesn't have to be prefill; it can be the first
+            # iteration with a question and a cached system prompt (continuing generation from cache)
             model_inputs["pixel_values"] = pixel_values
 
         return model_inputs
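The switch from `cache_position[0] == 0` to `is_first_iteration` matters for "continue from cache" calls: with a pre-filled system-prompt cache, the first step of a new `generate()` call starts at a nonzero cache position, yet pixel values still need to reach the model. The gating condition, distilled into a truth table (the helper name is ours, not the library's):

```python
def should_pass_pixel_values(is_first_iteration: bool, use_cache: bool) -> bool:
    # Mirrors `is_first_iteration or not kwargs.get("use_cache", True)` above.
    return is_first_iteration or not use_cache

assert should_pass_pixel_values(True, True)       # prefill, or first step on a pre-filled cache
assert not should_pass_pixel_values(False, True)  # cached decoding: features already merged
assert should_pass_pixel_values(False, False)     # cacheless decoding recomputes every step
```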
@@ -157,7 +157,9 @@ class Mistral3Model(LlavaModel):
 
         image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
         downsample_ratio = self.vision_tower.patch_size * self.config.spatial_merge_size
-        split_sizes = [(height // downsample_ratio) * (width // downsample_ratio) for height, width in image_sizes]
+        split_sizes = (
+            (torch.as_tensor(image_sizes, device=image_features.device) // downsample_ratio).prod(dim=-1).tolist()
+        )
         image_features = torch.split(image_features.squeeze(0), split_sizes)
         return image_features
 
@@ -37,7 +37,12 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -50,11 +55,12 @@ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPas
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
 from ...utils.generic import OutputRecorder, maybe_autocast
 from .configuration_mixtral import MixtralConfig
 
 
+@use_experts_implementation
 class MixtralExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
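The new `use_experts_implementation` decorator marks the experts containers so that an alternative forward (for example a grouped-GEMM kernel) can be substituted for the eager expert loop. Its real logic lives in `transformers.integrations`; purely as an illustration of the shape such a class decorator can take (names and dispatch here are assumptions, not the actual implementation):

```python
from torch import nn

def use_experts_implementation(cls: type[nn.Module]) -> type[nn.Module]:
    # Hypothetical sketch: keep the eager forward, but let loading code install
    # a faster implementation on the instance via `_experts_impl`.
    eager_forward = cls.forward

    def forward(self, *args, **kwargs):
        impl = getattr(self, "_experts_impl", None)
        if impl is not None:
            return impl(self, *args, **kwargs)
        return eager_forward(self, *args, **kwargs)

    cls.forward = forward
    return cls
```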
@@ -169,7 +175,7 @@ class MixtralRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -403,7 +409,9 @@ class MixtralPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(MixtralTopKRouter, index=0),
@@ -28,12 +28,13 @@ from torch import nn
 from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
+from ...integrations import use_experts_implementation
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, logging
+from ...utils import TransformersKwargs, is_grouped_mm_available, logging
 from ...utils.generic import OutputRecorder
 from ..mistral.modeling_mistral import (
     MistralAttention,
@@ -134,6 +135,7 @@ def load_balancing_loss_func(
     return overall_loss * num_experts
 
 
+@use_experts_implementation
 class MixtralExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
@@ -263,7 +265,9 @@ class MixtralDecoderLayer(GradientCheckpointingLayer):
 
 
 class MixtralPreTrainedModel(MistralPreTrainedModel):
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _can_record_outputs = {
         "router_logits": OutputRecorder(MixtralTopKRouter, index=0),
         "hidden_states": MixtralDecoderLayer,
@@ -55,6 +55,8 @@ class MLCDRotaryEmbedding(nn.Module):
 
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
+        self.dim = dim
+        self.theta = theta
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
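Storing `dim` and `theta` on the module lets `_init_weights` (next hunks) rebuild `inv_freq` deterministically instead of relying on checkpointed state. For intuition, the ladder the formula produces:

```python
import torch

dim, theta = 8, 10000.0  # example values; MLCD's real head dim will differ
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
# Exponents 0/8, 2/8, 4/8, 6/8 yield frequencies 10000**0, 10000**-0.25, ...:
print(inv_freq)  # tensor([1.0000, 0.1000, 0.0100, 0.0010])
```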
@@ -424,6 +426,7 @@ class MLCDPreTrainedModel(PreTrainedModel):
             factor = self.config.initializer_factor
             init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MLCDAttention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
@@ -447,6 +450,9 @@ class MLCDPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, nn.Linear) and module.bias is not None:
             init.zeros_(module.bias)
+        elif isinstance(module, MLCDRotaryEmbedding):
+            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
 
 
 class MLCDVisionTransformer(nn.Module):
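The `position_ids` re-seed (also used for MobileBERT further down) simply restores the canonical index row; `expand((1, -1))` adds the batch axis as a view, without copying:

```python
import torch

num_positions = 6  # hypothetical sequence length
position_ids = torch.arange(num_positions).expand((1, -1))
print(position_ids)        # tensor([[0, 1, 2, 3, 4, 5]])
print(position_ids.shape)  # torch.Size([1, 6])
```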
@@ -363,6 +363,7 @@ class MLCDPreTrainedModel(PreTrainedModel):
             factor = self.config.initializer_factor
             init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MLCDAttention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
@@ -386,6 +387,9 @@ class MLCDPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, nn.Linear) and module.bias is not None:
             init.zeros_(module.bias)
+        elif isinstance(module, MLCDRotaryEmbedding):
+            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
 
 
 class MLCDVisionTransformer(CLIPVisionTransformer):
@@ -741,7 +741,7 @@ class MllamaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -847,6 +847,15 @@ class MllamaPreTrainedModel(PreTrainedModel):
         elif isinstance(module, MllamaPrecomputedAspectRatioEmbedding):
             if module.is_gated:
                 init.zeros_(module.gate)
+        elif isinstance(module, MllamaRotaryEmbedding):
+            rope_fn = (
+                ROPE_INIT_FUNCTIONS[module.rope_type]
+                if module.rope_type != "default"
+                else module.compute_default_rope_parameters
+            )
+            buffer_value, _ = rope_fn(module.config)
+            init.copy_(module.inv_freq, buffer_value)
+            init.copy_(module.original_inv_freq, buffer_value)
 
     # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
     def _update_causal_mask(
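These rotary re-init branches (here, and for Mimi above) exist because non-persistent buffers never reach the checkpoint: when a model is built on the meta device and then loaded from a state_dict, `inv_freq` stays unmaterialized unless `_init_weights` recomputes it. A standalone demonstration of the underlying behavior:

```python
import torch
from torch import nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("inv_freq", torch.ones(4), persistent=False)

m = WithBuffer()
print(list(m.state_dict()))  # [] - the buffer is never saved
with torch.device("meta"):
    meta_model = WithBuffer()
print(meta_model.inv_freq.device)  # meta - must be recomputed after loading
```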
@@ -1721,6 +1730,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
         use_cache=False,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1738,12 +1748,13 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
             cross_attention_mask=cross_attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
         # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios
         # to compute image hidden states, otherwise they are cached within each cross attn layer
-        if cache_position[0] != 0:
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None
             model_inputs["aspect_ratio_ids"] = None
             model_inputs["aspect_ratio_mask"] = None
@@ -38,7 +38,7 @@ class MMGroundingDinoConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.
 
     Args:
-        backbone_config (`PreTrainedConfig` or `dict`, *optional*, defaults to `ResNetConfig()`):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `SwinConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -280,7 +280,6 @@ class MMGroundingDinoConfig(PreTrainedConfig):
         self.layer_norm_eps = layer_norm_eps
 
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True
 
 
 __all__ = ["MMGroundingDinoConfig"]
@@ -552,7 +552,7 @@ class MMGroundingDinoPreTrainedModel(PreTrainedModel):
         elif isinstance(module, MMGroundingDinoFusionLayer):
             init.constant_(module.vision_param, 1e-4)
             init.constant_(module.text_param, 1e-4)
-        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
             init.normal_(module.weight, mean=0.0, std=std)
             if module.bias is not None:
                 init.zeros_(module.bias)
@@ -1181,7 +1181,7 @@ class MMGroundingDinoEncoder(MMGroundingDinoPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         **kwargs,
-    ):
+    ) -> Union[tuple, MMGroundingDinoEncoderOutput]:
         r"""
         Args:
             vision_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -1478,7 +1478,7 @@ class MMGroundingDinoDecoder(MMGroundingDinoPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         **kwargs,
-    ):
+    ) -> Union[tuple, MMGroundingDinoDecoderOutput]:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
@@ -1954,7 +1954,7 @@ class MMGroundingDinoModel(MMGroundingDinoPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         **kwargs,
-    ):
+    ) -> Union[tuple, MMGroundingDinoModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -51,7 +51,7 @@ class MMGroundingDinoConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.
 
     Args:
-        backbone_config (`PreTrainedConfig` or `dict`, *optional*, defaults to `ResNetConfig()`):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `SwinConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -293,7 +293,6 @@ class MMGroundingDinoConfig(PreTrainedConfig):
         self.layer_norm_eps = layer_norm_eps
 
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True
 
 
 class MMGroundingDinoContrastiveEmbedding(GroundingDinoContrastiveEmbedding):
@@ -556,6 +556,8 @@ class MobileBertPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, MobileBertLMPredictionHead):
             init.zeros_(module.bias)
+        elif isinstance(module, MobileBertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
 
 
 @dataclass
@@ -180,7 +180,6 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
 
        # Stack all processed images if return_tensors is specified
-       processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
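The removed `torch.stack` duplicated work: with `tensor_type` set, batching is delegated to `BatchFeature`'s conversion step, so the images were effectively batched twice (this assumes the rc2 `BatchFeature` stacks lists of tensors, which the removal implies). What the deleted line was doing, in isolation:

```python
import torch

# Batching a list of equally-sized images into one (N, C, H, W) tensor -
# the step now delegated to BatchFeature's tensor_type conversion.
images = [torch.zeros(3, 2, 2), torch.zeros(3, 2, 2)]
stacked = torch.stack(images, dim=0)
print(stacked.shape)  # torch.Size([2, 3, 2, 2])
```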