transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/git/modeling_git.py
@@ -26,8 +26,9 @@ from torch import nn
  from ... import initialization as init
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache
+ from ...configuration_utils import PreTrainedConfig
  from ...generation import GenerationMixin
- from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+ from ...masking_utils import create_masks_for_generate
  from ...modeling_layers import GradientCheckpointingLayer
  from ...modeling_outputs import (
      BaseModelOutput,
@@ -69,6 +70,104 @@ class GitVisionModelOutput(ModelOutput):
      attentions: Optional[tuple[torch.FloatTensor, ...]] = None


+ # Copied from transformers.models.gemma3.modeling_gemma3.token_type_ids_mask_function
+ def token_type_ids_mask_function(
+     token_type_ids: Optional[torch.Tensor],
+     image_group_ids: Optional[torch.Tensor],
+ ) -> Optional[Callable]:
+     """
+     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
+     not start and end indices.
+     """
+     # Do not return an additional mask in this case
+     if token_type_ids is None:
+         return None
+
+     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+         # If it's 1 for both query and key/value, we are in an image block
+         # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
+         # Since vmap doesn't support `if statement` we workaround it with `torch.where`
+         safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
+         safe_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
+
+         token_type_ids_at_q_idx = token_type_ids[batch_idx, safe_q_idx]
+         token_type_ids_at_q_idx = torch.where(q_idx < token_type_ids.shape[1], token_type_ids_at_q_idx, 0)
+
+         token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_kv_idx]
+         token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
+
+         image_group_ids_at_q_idx = image_group_ids[batch_idx, safe_q_idx]
+         image_group_ids_at_q_idx = torch.where(q_idx < image_group_ids.shape[1], image_group_ids_at_q_idx, -1)
+
+         image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_kv_idx]
+         image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
+
+         is_image_block = (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)
+         same_image_block = image_group_ids_at_q_idx == image_group_ids_at_kv_idx
+
+         # This is bidirectional attention whenever we are dealing with image tokens
+         return is_image_block & same_image_block
+
+     return inner_mask
+
+
+ # Copied from transformers.models.gemma3.modeling_gemma3.create_causal_mask_mapping
+ def create_causal_mask_mapping(
+     config: PreTrainedConfig,
+     input_embeds: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     cache_position: torch.Tensor,
+     past_key_values: Optional[Cache],
+     position_ids: Optional[torch.Tensor],
+     token_type_ids: Optional[torch.Tensor] = None,
+     pixel_values: Optional[torch.FloatTensor] = None,
+     is_training: bool = False,
+     is_first_iteration: Optional[bool] = None,
+     **kwargs,
+ ) -> dict:
+     """
+     Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
+     for all kinds of forward passes. Gemma3 uses a bidirectional mask for images.
+
+     Uses `pixel_values` as an optional input to disambiguate edge cases.
+     """
+     if is_training and token_type_ids is None:
+         raise ValueError("`token_type_ids` is required as a model input when training")
+
+     mask_kwargs = {
+         "config": config.get_text_config(),
+         "input_embeds": input_embeds,
+         "attention_mask": attention_mask,
+         "cache_position": cache_position,
+         "past_key_values": past_key_values,
+         "position_ids": position_ids,
+     }
+     # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
+     # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
+     # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
+     is_first_iteration = (
+         is_first_iteration
+         if is_first_iteration is not None
+         else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
+     )
+     if token_type_ids is not None and is_first_iteration:
+         # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
+         # undo the causal masking)
+
+         # First find where a new image block starts: 1 if image and previous not image
+         # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
+         is_image = (token_type_ids == 1).to(cache_position.device)
+         is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
+         new_image_start = is_image & ~is_previous_image
+         image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
+         image_group_ids = torch.where(is_image, image_group_ids, -1)
+         mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
+             token_type_ids.to(cache_position.device), image_group_ids
+         )
+
+     return create_masks_for_generate(**mask_kwargs)
+
+
  class GitEmbeddings(nn.Module):
      """Construct the embeddings from word and position embeddings."""

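The grouping trick used by the two helpers above is compact but easy to misread. A minimal standalone sketch (toy `token_type_ids`, made-up values, not part of the diff) of how contiguous image tokens each get their own group id, so that `inner_mask` makes attention bidirectional only within a single image block:

import torch
from torch import nn

# text tokens are 0, image tokens are 1; two separate image blocks in one sequence
token_type_ids = torch.tensor([[0, 1, 1, 0, 1, 1, 1, 0]])

is_image = token_type_ids == 1
# shift right by one position to compare each token with its predecessor
is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
new_image_start = is_image & ~is_previous_image
# running count of block starts gives each image block a distinct id, text gets -1
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_ids = torch.where(is_image, image_group_ids, -1)

print(image_group_ids)  # -> [[-1, 0, 0, -1, 1, 1, 1, -1]]
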
@@ -148,17 +247,15 @@ class GitSelfAttention(nn.Module):
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          past_key_values: Optional[Cache] = None,
-         output_attentions: Optional[bool] = False,
-         pixel_values_present: Optional[bool] = False,
+         cache_position: Optional[torch.Tensor] = None,
      ) -> tuple[torch.Tensor]:
-         batch_size, seq_length, _ = hidden_states.shape
+         batch_size = hidden_states.shape[0]
          query_layer = (
              self.query(hidden_states)
              .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
              .transpose(1, 2)
          )

-         cutoff = self.image_patch_tokens if pixel_values_present else 0
          key_layer = (
              self.key(hidden_states)
              .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
@@ -170,12 +267,9 @@
              .transpose(1, 2)
          )
          if past_key_values is not None:
-             # NOTE: like in other caches, we store the text component. In GIT it means we discard the image component.
-             key_layer_past, value_layer_past = past_key_values.update(
-                 key_layer[:, :, cutoff:, :], value_layer[:, :, cutoff:, :], self.layer_idx
+             key_layer, value_layer = past_key_values.update(
+                 key_layer, value_layer, self.layer_idx, cache_kwargs={"cache_position": cache_position}
              )
-             key_layer = torch.cat([key_layer[:, :, :cutoff, :], key_layer_past], dim=2)
-             value_layer = torch.cat([value_layer[:, :, :cutoff, :], value_layer_past], dim=2)

          # Take the dot product between "query" and "key" to get the raw attention scores.
          attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
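For context on the cache change in this hunk: previously only the text part of the key/value states (everything after the `cutoff` image tokens) was stored in the cache and the image part was concatenated back on every step; now the full sequence, image tokens included, is cached and indexed by `cache_position`. A toy sketch with plain tensors (illustrative shapes only, no `Cache` object), not part of the diff:

import torch

num_image_tokens, num_text_tokens, head_dim = 2, 3, 4
key_layer = torch.randn(1, 1, num_image_tokens + num_text_tokens, head_dim)

# old behaviour: cache only the text slice, re-prepend the image slice each step
cutoff = num_image_tokens
cached_text_only = key_layer[:, :, cutoff:, :]
reassembled = torch.cat([key_layer[:, :, :cutoff, :], cached_text_only], dim=2)

# new behaviour: the whole key sequence goes into the cache
cached_full = key_layer

assert torch.equal(reassembled, cached_full)
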
@@ -232,15 +326,14 @@ class GitAttention(nn.Module):
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          past_key_values: Optional[Cache] = None,
+         cache_position: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = False,
-         pixel_values_present: Optional[bool] = False,
      ) -> tuple[torch.Tensor]:
          attn_output, self_attn_weights = self.self(
              hidden_states,
              attention_mask,
              past_key_values,
-             output_attentions,
-             pixel_values_present,
+             cache_position=cache_position,
          )
          attention_output = self.output(attn_output, hidden_states)
          return attention_output, self_attn_weights
@@ -291,8 +384,8 @@ class GitLayer(GradientCheckpointingLayer):
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          past_key_values: Optional[Cache] = None,
+         cache_position: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = False,
-         pixel_values_present: Optional[bool] = False,
      ) -> tuple[torch.Tensor]:
          # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
          attention_output, self_attention_weights = self.attention(
@@ -300,7 +393,7 @@ class GitLayer(GradientCheckpointingLayer):
              attention_mask,
              output_attentions=output_attentions,
              past_key_values=past_key_values,
-             pixel_values_present=pixel_values_present,
+             cache_position=cache_position,
          )

          layer_output = apply_chunking_to_forward(
@@ -329,8 +422,8 @@ class GitEncoder(nn.Module):
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = False,
          output_hidden_states: Optional[bool] = False,
-         pixel_values_present: Optional[bool] = False,
          return_dict: Optional[bool] = True,
+         cache_position: Optional[torch.Tensor] = None,
      ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPast]:
          if self.gradient_checkpointing and self.training:
              if use_cache:
@@ -353,7 +446,7 @@
                  attention_mask,
                  past_key_values,
                  output_attentions,
-                 pixel_values_present,
+                 cache_position,
              )

              hidden_states = layer_outputs[0]
@@ -396,6 +489,7 @@ class GitPreTrainedModel(PreTrainedModel):
              init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range)
              init.normal_(module.patch_embedding.weight, std=self.config.initializer_range)
              init.normal_(module.position_embedding.weight, std=self.config.initializer_range)
+             init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
          if isinstance(module, nn.Linear):
              init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
              if module.bias is not None:
@@ -408,6 +502,8 @@
          elif isinstance(module, nn.LayerNorm):
              init.zeros_(module.bias)
              init.ones_(module.weight)
+         elif isinstance(module, GitEmbeddings):
+             init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


  # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git
@@ -903,62 +999,6 @@ class GitModel(GitPreTrainedModel):
      def set_input_embeddings(self, value):
          self.embeddings.word_embeddings = value

-     def _generate_future_mask(self, size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
-         # Default mask is for forward direction. Flip for backward direction.
-         mask = torch.triu(torch.ones(size, size, device=device, dtype=dtype), diagonal=1)
-         mask = mask.masked_fill(mask == 1, float("-inf"))
-         return mask
-
-     def create_attention_mask(self, tgt, memory, tgt_mask, past_key_values_length, memory_key_padding_mask=None):
-         num_tgt = tgt.shape[1]
-         num_memory = memory.shape[1]
-         device = tgt.device
-         dtype = tgt.dtype
-         top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
-         top_right = torch.full(
-             (num_memory, num_tgt + past_key_values_length),
-             float("-inf"),
-             device=tgt.device,
-             dtype=dtype,
-         )
-         bottom_left = torch.zeros(
-             (num_tgt, num_memory),
-             dtype=dtype,
-             device=tgt_mask.device,
-         )
-
-         if past_key_values_length > 0:
-             tgt_mask = torch.zeros(
-                 (tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length),
-                 dtype=dtype,
-                 device=tgt_mask.device,
-             )
-
-         left = torch.cat((top_left, bottom_left), dim=0)
-         right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)
-
-         full_attention_mask = torch.cat((left, right), dim=1)[None, :]
-
-         if memory_key_padding_mask is None:
-             memory_key_padding_mask = torch.full((memory.shape[0], memory.shape[1]), fill_value=False, device=device)
-         # if it is False, it means valid. That is, it is not a padding
-         if memory_key_padding_mask.dtype != torch.bool:
-             raise ValueError("Memory key padding mask must be a boolean tensor.")
-         zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
-         zero_negative_infinity[memory_key_padding_mask] = float("-inf")
-         full_attention_mask = full_attention_mask.expand(
-             (memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + past_key_values_length + num_tgt)
-         )
-         full_attention_mask = full_attention_mask.clone()
-         origin_left = full_attention_mask[:, :, :num_memory]
-         update = zero_negative_infinity[:, None, :]
-         full_attention_mask[:, :, :num_memory] = origin_left + update
-
-         # add axis for multi-head
-         full_attention_mask = full_attention_mask[:, None, :, :]
-
-         return full_attention_mask
-
      @auto_docstring
      def forward(
          self,
@@ -973,6 +1013,7 @@ class GitModel(GitPreTrainedModel):
  output_hidden_states: Optional[bool] = None,
  interpolate_pos_encoding: bool = False,
  return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
  **kwargs,
  ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
  r"""
@@ -1005,15 +1046,6 @@ class GitModel(GitPreTrainedModel):

  if input_ids is not None and inputs_embeds is not None:
  raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
- elif input_ids is not None:
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
- input_shape = input_ids.size()
- elif inputs_embeds is not None:
- input_shape = inputs_embeds.size()[:-1]
- else:
- raise ValueError("You have to specify either input_ids or inputs_embeds")
-
- seq_length = input_shape[1]

  # past_key_values_length
  past_key_values_length = 0
@@ -1024,7 +1056,23 @@ class GitModel(GitPreTrainedModel):
  else past_key_values.get_seq_length()
  )

- projected_visual_features = None
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_key_values_length,
+ past_key_values_length + embedding_output.shape[1],
+ device=embedding_output.device,
+ )
+
+ # Always create `token_type_ids` so we can re-use Gemma3 style mask preparation fn
+ token_type_ids = torch.zeros_like(embedding_output, dtype=torch.int)[..., 0]
+
  if pixel_values is not None:
  if pixel_values.ndim == 4:
  # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
@@ -1050,60 +1098,54 @@ class GitModel(GitPreTrainedModel):

  projected_visual_features = self.visual_projection(visual_features)

- embedding_output = self.embeddings(
- input_ids=input_ids,
- position_ids=position_ids,
- inputs_embeds=inputs_embeds,
- past_key_values_length=past_key_values_length,
- )
-
- if projected_visual_features is None:
- projected_visual_features = torch.zeros(
- (embedding_output.shape[0], 0, embedding_output.shape[2]),
- dtype=embedding_output.dtype,
- device=embedding_output.device,
+ # Repeat visual features to match embedding batch size.
+ projected_visual_features = projected_visual_features.repeat(
+ embedding_output.size(0) // projected_visual_features.size(0), 1, 1
  )

- # Repeat visual features to match embedding batch size.
- projected_visual_features = projected_visual_features.repeat(
- embedding_output.size(0) // projected_visual_features.size(0), 1, 1
- )
-
- # concatenate patch token and text token embeddings
- hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
-
- # By default, an additive causal mask is created
- # for masking the future (one direction).
- tgt_mask = self._generate_future_mask(seq_length, embedding_output.dtype, embedding_output.device)
+ # concatenate patch token and text token embeddings
+ embedding_output = torch.cat((projected_visual_features, embedding_output), dim=1)
+ image_token_type_ids = torch.ones_like(projected_visual_features, dtype=torch.int)[..., 0]
+ token_type_ids = torch.cat([image_token_type_ids, token_type_ids], dim=-1)
+ cache_position = torch.arange(embedding_output.shape[1], device=embedding_output.device, dtype=torch.int)
+ if attention_mask is not None:
+ attention_mask = torch.cat([torch.ones_like(image_token_type_ids), attention_mask], dim=-1)
+ elif past_key_values is not None and input_ids.shape[1] == 1:
+ # Expand attention mask and cache position with image tokens because GIT doesn't add image
+ # placeholder tokens when processing. Doesn't worth the refactor, low usage!
+ cache_position = torch.tensor(
+ [past_key_values_length], dtype=cache_position.dtype, device=cache_position.device
+ )
+ extended_attention_mask = torch.ones(
+ (attention_mask.shape[0], past_key_values_length - attention_mask.shape[1] + 1),
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ )
+ attention_mask = torch.cat([extended_attention_mask, attention_mask], dim=-1)

- # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
- combined_attention_mask = self.create_attention_mask(
- tgt=embedding_output,
- memory=projected_visual_features,
- tgt_mask=tgt_mask,
- past_key_values_length=past_key_values_length,
+ # Images attend each other bidirectionally while text remains causal
+ causal_mask = create_causal_mask_mapping(
+ self.config,
+ embedding_output,
+ attention_mask,
+ cache_position,
+ past_key_values,
+ None,
+ token_type_ids,
+ pixel_values,
  )

- if attention_mask is not None:
- # if the user provides an attention mask, we add it to the default one
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- expanded_attn_mask = _prepare_4d_attention_mask(
- attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]
- ).to(embedding_output.device)
- if past_key_values_length > 0:
- expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
- else:
- combined_attention_mask[:, :, -input_shape[1] :, -input_shape[1] :] += expanded_attn_mask
+ hidden_states = embedding_output

  encoder_outputs = self.encoder(
  hidden_states,
- attention_mask=combined_attention_mask,
+ attention_mask=causal_mask,
  past_key_values=past_key_values,
  use_cache=use_cache,
  output_attentions=output_attentions,
  output_hidden_states=output_hidden_states,
  return_dict=return_dict,
- pixel_values_present=pixel_values is not None,
+ cache_position=cache_position,
  )
  sequence_output = encoder_outputs[0]
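This hunk drops GIT's hand-rolled `create_attention_mask` in favor of the shared `create_causal_mask_mapping` helper, driven by `token_type_ids` that mark image tokens. A minimal pure-`torch` sketch of the intended pattern (not the library helper, whose signature is only what the call above shows): image tokens attend to each other bidirectionally, text tokens stay causal, and image tokens never look at text.

    import torch

    def sketch_git_style_mask(token_type_ids: torch.Tensor) -> torch.Tensor:
        # token_type_ids: (batch, seq_len), 1 for image tokens, 0 for text tokens
        seq_len = token_type_ids.shape[-1]
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        is_image = token_type_ids.bool()
        image_block = is_image[:, :, None] & is_image[:, None, :]  # bidirectional image-image block
        allowed = causal[None, :, :] | image_block                 # causal text, full image-image
        return allowed[:, None, :, :]                              # (batch, 1, seq, seq)

    # 3 image tokens followed by 4 text tokens
    print(sketch_git_style_mask(torch.tensor([[1, 1, 1, 0, 0, 0, 0]])).int()[0, 0])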
 
@@ -1157,6 +1199,7 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
  interpolate_pos_encoding: bool = False,
  return_dict: Optional[bool] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
+ cache_position: Optional[torch.Tensor] = None,
  **kwargs,
  ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]:
  r"""
@@ -1306,6 +1349,7 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
  output_hidden_states=output_hidden_states,
  interpolate_pos_encoding=interpolate_pos_encoding,
  return_dict=return_dict,
+ cache_position=cache_position,
  )

  hidden_states = outputs[0]
@@ -1339,7 +1383,15 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
  )

  def prepare_inputs_for_generation(
- self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+ self,
+ input_ids,
+ past_key_values=None,
+ pixel_values=None,
+ attention_mask=None,
+ use_cache=None,
+ cache_position=None,
+ is_first_iteration=False,
+ **kwargs,
  ):
  # Overwritten -- `git` has special cache handling and doesn't support generating from `inputs_embeds` atm
@@ -1364,11 +1416,14 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
  model_inputs = {
  "input_ids": input_ids,
  "attention_mask": attention_mask,
- "pixel_values": kwargs.get("pixel_values"),
  "past_key_values": past_key_values,
  "use_cache": use_cache,
+ "cache_position": cache_position,
  }

+ if is_first_iteration or not use_cache:
+ model_inputs["pixel_values"] = pixel_values
+
  # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
  for key, value in kwargs.items():
  if key not in model_inputs:
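The reworked `prepare_inputs_for_generation` above only forwards `pixel_values` on the first generation step (or when caching is disabled); afterwards the projected visual features already live in the key/value cache. A tiny sketch of that dispatch, with a made-up helper name:

    def select_pixel_values(pixel_values, use_cache, is_first_iteration):
        # prefill step (or no cache): the image must be encoded alongside the prompt
        if is_first_iteration or not use_cache:
            return pixel_values
        # decode steps: rely on the cached keys/values instead of re-encoding
        return None

    print(select_pixel_values("img", use_cache=True, is_first_iteration=True))   # img
    print(select_pixel_values("img", use_cache=True, is_first_iteration=False))  # None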
@@ -79,7 +79,7 @@ class GlmRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

  @staticmethod
  def compute_default_rope_parameters(
@@ -284,7 +284,7 @@ class Glm4RotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

  @staticmethod
  def compute_default_rope_parameters(
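Both rotary-embedding hunks above switch `original_inv_freq` from a plain attribute to a non-persistent buffer holding a clone. A small standalone sketch (not the transformers class) of why that matters: buffers follow `.to(...)` moves and casts, stay out of the state dict when `persistent=False`, and the clone keeps the saved copy independent of later updates to `inv_freq`.

    import torch
    from torch import nn

    class RotarySketch(nn.Module):
        def __init__(self, dim: int = 8, base: float = 10000.0):
            super().__init__()
            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    m = RotarySketch()
    print(list(m.state_dict().keys()))  # [] -- non-persistent buffers are not serialized
    m = m.to(torch.float16)             # buffers follow module casts and device moves
    print(m.original_inv_freq.dtype)    # torch.float16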
@@ -354,7 +354,6 @@ class Glm46VImageProcessor(BaseImageProcessor):
  image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
  Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
  `True`.
- The max pixels of the image to resize the image.
  patch_size (`int`, *optional*, defaults to `self.patch_size`):
  The spatial patch size of the vision encoder.
  temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
@@ -381,12 +380,9 @@
  - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

  """
- # Try to use config values if set, otherwise fallback to global defaults
  size = size if size is not None else self.size
  if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
  raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
- elif size is None:
- size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000}

  do_resize = do_resize if do_resize is not None else self.do_resize
  resample = resample if resample is not None else self.resample
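With the silent default removed, an explicit `size` override must now carry both keys or the processor raises; passing `size=None` simply falls back to `self.size`. Illustrative values only (these are not the processor's defaults):

    def validate_size(size):
        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
        return size

    validate_size({"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280})  # ok
    validate_size(None)                                                        # ok, caller uses self.size
    # validate_size({"shortest_edge": 3136})  # raises ValueError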
@@ -639,6 +639,7 @@ class Glm46VForConditionalGeneration(Glm46VPreTrainedModel, GenerationMixin):
  pixel_values_videos=None,
  image_grid_thw=None,
  video_grid_thw=None,
+ is_first_iteration=False,
  **kwargs,
  ):
  # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -655,13 +656,14 @@ class Glm46VForConditionalGeneration(Glm46VPreTrainedModel, GenerationMixin):
  image_grid_thw=image_grid_thw,
  video_grid_thw=video_grid_thw,
  use_cache=use_cache,
+ is_first_iteration=is_first_iteration,
  **kwargs,
  )

  # GLM-4.1V position_ids are prepareed with rope_deltas in forward
  model_inputs["position_ids"] = None

- if cache_position[0] != 0:
+ if not is_first_iteration and use_cache:
  model_inputs["pixel_values"] = None
  model_inputs["pixel_values_videos"] = None

@@ -110,6 +110,9 @@ class Glm46VPreTrainedModel(Glm4vPreTrainedModel):
  _can_record_outputs = None
  _no_split_modules = None

+ def _init_weights(self, module):
+ raise AttributeError("Not needed")
+

  class Glm46VModel(Glm4vModel):
  _no_split_modules = None
@@ -30,7 +30,7 @@ from ... import initialization as init
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
+ from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func
  from ...masking_utils import create_causal_mask
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
  from ...modeling_layers import GradientCheckpointingLayer
@@ -38,7 +38,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
  from ...utils.generic import check_model_inputs, maybe_autocast
  from .configuration_glm4_moe import Glm4MoeConfig

@@ -60,7 +60,7 @@ class Glm4MoeRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

  @staticmethod
  def compute_default_rope_parameters(
@@ -332,6 +332,7 @@ class Glm4MoeRMSNorm(nn.Module):
  return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


+ @use_experts_implementation
  class Glm4MoeNaiveMoe(nn.Module):
  """Collection of expert weights stored as 3D tensors."""

@@ -339,7 +340,7 @@ class Glm4MoeNaiveMoe(nn.Module):
  super().__init__()
  self.num_experts = config.num_local_experts
  self.hidden_dim = config.hidden_size
- self.intermediate_dim = config.intermediate_size
+ self.intermediate_dim = config.moe_intermediate_size
  self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
  self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
  self.act_fn = ACT2FN[config.hidden_act]
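The fix above sizes the stacked expert projections from `moe_intermediate_size` (the per-expert FFN width) instead of the dense `intermediate_size`. A quick shape check with made-up config numbers:

    import torch

    num_experts, hidden_size, moe_intermediate_size = 8, 2048, 1408  # illustrative only

    gate_up_proj = torch.empty(num_experts, 2 * moe_intermediate_size, hidden_size)
    down_proj = torch.empty(num_experts, hidden_size, moe_intermediate_size)

    print(gate_up_proj.shape)  # torch.Size([8, 2816, 2048])
    print(down_proj.shape)     # torch.Size([8, 2048, 1408])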
@@ -486,7 +487,9 @@ class Glm4MoePreTrainedModel(PreTrainedModel):
  _supports_flash_attn = True
  _supports_sdpa = True
  _supports_flex_attn = True
- _can_compile_fullgraph = False
+ _can_compile_fullgraph = (
+ is_grouped_mm_available()
+ ) # https://huggingface.co/docs/transformers/experts_interface#torchcompile
  _supports_attention_backend = True
  _can_record_outputs = {
  "hidden_states": Glm4MoeDecoderLayer,
@@ -499,6 +502,7 @@ class Glm4MoePreTrainedModel(PreTrainedModel):
  super()._init_weights(module)
  if isinstance(module, Glm4MoeTopkRouter):
  init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+ init.zeros_(module.e_score_correction_bias)
  elif isinstance(module, Glm4MoeNaiveMoe):
  init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
  init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
@@ -273,7 +273,7 @@ class Glm4MoeDecoderLayer(DeepseekV3DecoderLayer):


  class Glm4MoePreTrainedModel(DeepseekV3PreTrainedModel):
- _can_compile_fullgraph = False
+ pass


  class Glm4MoeModel(DeepseekV3Model):
@@ -353,7 +353,6 @@ class Glm4vImageProcessor(BaseImageProcessor):
  image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
  Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
  `True`.
- The max pixels of the image to resize the image.
  patch_size (`int`, *optional*, defaults to `self.patch_size`):
  The spatial patch size of the vision encoder.
  temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
@@ -380,12 +379,9 @@
  - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

  """
- # Try to use config values if set, otherwise fallback to global defaults
  size = size if size is not None else self.size
  if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
  raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
- elif size is None:
- size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000}

  do_resize = do_resize if do_resize is not None else self.do_resize
  resample = resample if resample is not None else self.resample