transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -26,8 +26,9 @@ from torch import nn
 from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
+from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...masking_utils import create_masks_for_generate
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -69,6 +70,104 @@ class GitVisionModelOutput(ModelOutput):
      attentions: Optional[tuple[torch.FloatTensor, ...]] = None


+ # Copied from transformers.models.gemma3.modeling_gemma3.token_type_ids_mask_function
+ def token_type_ids_mask_function(
+     token_type_ids: Optional[torch.Tensor],
+     image_group_ids: Optional[torch.Tensor],
+ ) -> Optional[Callable]:
+     """
+     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
+     not start and end indices.
+     """
+     # Do not return an additional mask in this case
+     if token_type_ids is None:
+         return None
+
+     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+         # If it's 1 for both query and key/value, we are in an image block
+         # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
+         # Since vmap doesn't support `if` statements, we work around them with `torch.where`
+         safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
+         safe_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
+
+         token_type_ids_at_q_idx = token_type_ids[batch_idx, safe_q_idx]
+         token_type_ids_at_q_idx = torch.where(q_idx < token_type_ids.shape[1], token_type_ids_at_q_idx, 0)
+
+         token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_kv_idx]
+         token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
+
+         image_group_ids_at_q_idx = image_group_ids[batch_idx, safe_q_idx]
+         image_group_ids_at_q_idx = torch.where(q_idx < image_group_ids.shape[1], image_group_ids_at_q_idx, -1)
+
+         image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_kv_idx]
+         image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
+
+         is_image_block = (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)
+         same_image_block = image_group_ids_at_q_idx == image_group_ids_at_kv_idx
+
+         # This is bidirectional attention whenever we are dealing with image tokens
+         return is_image_block & same_image_block
+
+     return inner_mask
+
+
+ # Copied from transformers.models.gemma3.modeling_gemma3.create_causal_mask_mapping
+ def create_causal_mask_mapping(
+     config: PreTrainedConfig,
+     input_embeds: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     cache_position: torch.Tensor,
+     past_key_values: Optional[Cache],
+     position_ids: Optional[torch.Tensor],
+     token_type_ids: Optional[torch.Tensor] = None,
+     pixel_values: Optional[torch.FloatTensor] = None,
+     is_training: bool = False,
+     is_first_iteration: Optional[bool] = None,
+     **kwargs,
+ ) -> dict:
+     """
+     Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
+     for all kinds of forward passes. Gemma3 uses a bidirectional mask for images.
+
+     Uses `pixel_values` as an optional input to disambiguate edge cases.
+     """
+     if is_training and token_type_ids is None:
+         raise ValueError("`token_type_ids` is required as a model input when training")
+
+     mask_kwargs = {
+         "config": config.get_text_config(),
+         "input_embeds": input_embeds,
+         "attention_mask": attention_mask,
+         "cache_position": cache_position,
+         "past_key_values": past_key_values,
+         "position_ids": position_ids,
+     }
+     # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
+     # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
+     # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
+     is_first_iteration = (
+         is_first_iteration
+         if is_first_iteration is not None
+         else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
+     )
+     if token_type_ids is not None and is_first_iteration:
+         # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
+         # undo the causal masking)
+
+         # First find where a new image block starts: 1 if image and previous not image.
+         # Images cannot attend to future images, but can attend to all previous images and to themselves bidirectionally
+         is_image = (token_type_ids == 1).to(cache_position.device)
+         is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
+         new_image_start = is_image & ~is_previous_image
+         image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
+         image_group_ids = torch.where(is_image, image_group_ids, -1)
+         mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
+             token_type_ids.to(cache_position.device), image_group_ids
+         )
+
+     return create_masks_for_generate(**mask_kwargs)
+
+
  class GitEmbeddings(nn.Module):
      """Construct the embeddings from word and position embeddings."""
@@ -148,17 +247,15 @@ class GitSelfAttention(nn.Module):
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          past_key_values: Optional[Cache] = None,
-         output_attentions: Optional[bool] = False,
-         pixel_values_present: Optional[bool] = False,
+         cache_position: Optional[torch.Tensor] = None,
      ) -> tuple[torch.Tensor]:
-         batch_size, seq_length, _ = hidden_states.shape
+         batch_size = hidden_states.shape[0]
          query_layer = (
              self.query(hidden_states)
              .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
              .transpose(1, 2)
          )

-         cutoff = self.image_patch_tokens if pixel_values_present else 0
          key_layer = (
              self.key(hidden_states)
              .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
@@ -170,12 +267,9 @@
              .transpose(1, 2)
          )
          if past_key_values is not None:
-             # NOTE: like in other caches, we store the text component. In GIT it means we discard the image component.
-             key_layer_past, value_layer_past = past_key_values.update(
-                 key_layer[:, :, cutoff:, :], value_layer[:, :, cutoff:, :], self.layer_idx
+             key_layer, value_layer = past_key_values.update(
+                 key_layer, value_layer, self.layer_idx, cache_kwargs={"cache_position": cache_position}
              )
-             key_layer = torch.cat([key_layer[:, :, :cutoff, :], key_layer_past], dim=2)
-             value_layer = torch.cat([value_layer[:, :, :cutoff, :], value_layer_past], dim=2)

          # Take the dot product between "query" and "key" to get the raw attention scores.
          attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
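The old path sliced image tokens out before caching; the new path runs the whole key/value stream through the standard `Cache.update` contract. A minimal sketch of that contract, assuming the no-argument `DynamicCache` constructor and purely illustrative shapes:

    import torch
    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()
    # (batch, num_heads, seq_len, head_dim), illustrative sizes
    key = torch.randn(1, 12, 5, 64)
    value = torch.randn(1, 12, 5, 64)

    # update() stores the new states for layer 0 and returns everything cached so far
    key_out, value_out = cache.update(key, value, 0, cache_kwargs={"cache_position": torch.arange(5)})
    print(key_out.shape)  # torch.Size([1, 12, 5, 64]) on the first call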
@@ -232,15 +326,14 @@ class GitAttention(nn.Module):
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          past_key_values: Optional[Cache] = None,
+         cache_position: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = False,
-         pixel_values_present: Optional[bool] = False,
      ) -> tuple[torch.Tensor]:
          attn_output, self_attn_weights = self.self(
              hidden_states,
              attention_mask,
              past_key_values,
-             output_attentions,
-             pixel_values_present,
+             cache_position=cache_position,
          )
          attention_output = self.output(attn_output, hidden_states)
          return attention_output, self_attn_weights
@@ -291,8 +384,8 @@ class GitLayer(GradientCheckpointingLayer):
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          past_key_values: Optional[Cache] = None,
+         cache_position: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = False,
-         pixel_values_present: Optional[bool] = False,
      ) -> tuple[torch.Tensor]:
          # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
          attention_output, self_attention_weights = self.attention(
@@ -300,7 +393,7 @@ class GitLayer(GradientCheckpointingLayer):
              attention_mask,
              output_attentions=output_attentions,
              past_key_values=past_key_values,
-             pixel_values_present=pixel_values_present,
+             cache_position=cache_position,
          )

          layer_output = apply_chunking_to_forward(
@@ -329,8 +422,8 @@ class GitEncoder(nn.Module):
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = False,
          output_hidden_states: Optional[bool] = False,
-         pixel_values_present: Optional[bool] = False,
          return_dict: Optional[bool] = True,
+         cache_position: Optional[torch.Tensor] = None,
      ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPast]:
          if self.gradient_checkpointing and self.training:
              if use_cache:
@@ -353,7 +446,7 @@
                  attention_mask,
                  past_key_values,
                  output_attentions,
-                 pixel_values_present,
+                 cache_position,
              )

          hidden_states = layer_outputs[0]
@@ -396,6 +489,7 @@ class GitPreTrainedModel(PreTrainedModel):
              init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range)
              init.normal_(module.patch_embedding.weight, std=self.config.initializer_range)
              init.normal_(module.position_embedding.weight, std=self.config.initializer_range)
+             init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
          if isinstance(module, nn.Linear):
              init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
              if module.bias is not None:
@@ -408,6 +502,8 @@ class GitPreTrainedModel(PreTrainedModel):
          elif isinstance(module, nn.LayerNorm):
              init.zeros_(module.bias)
              init.ones_(module.weight)
+         elif isinstance(module, GitEmbeddings):
+             init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


  # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git
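The two new `init.copy_` calls rebuild the `position_ids` buffers during weight initialization; the value written is just the standard 0..N-1 index row. A sketch of what lands in the buffer, assuming a (1, max_len) buffer shape:

    import torch

    position_ids = torch.empty(1, 6, dtype=torch.long)  # illustrative max length of 6
    position_ids.copy_(torch.arange(position_ids.shape[-1]).expand((1, -1)))
    print(position_ids)  # tensor([[0, 1, 2, 3, 4, 5]])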
@@ -827,6 +923,7 @@ class GitVisionModel(GitPreTrainedModel):
          output_hidden_states: Optional[bool] = None,
          interpolate_pos_encoding: bool = False,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple, BaseModelOutput]:
          r"""
          Examples:
@@ -902,62 +999,6 @@ class GitModel(GitPreTrainedModel):
      def set_input_embeddings(self, value):
          self.embeddings.word_embeddings = value

-     def _generate_future_mask(self, size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
-         # Default mask is for forward direction. Flip for backward direction.
-         mask = torch.triu(torch.ones(size, size, device=device, dtype=dtype), diagonal=1)
-         mask = mask.masked_fill(mask == 1, float("-inf"))
-         return mask
-
-     def create_attention_mask(self, tgt, memory, tgt_mask, past_key_values_length, memory_key_padding_mask=None):
-         num_tgt = tgt.shape[1]
-         num_memory = memory.shape[1]
-         device = tgt.device
-         dtype = tgt.dtype
-         top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
-         top_right = torch.full(
-             (num_memory, num_tgt + past_key_values_length),
-             float("-inf"),
-             device=tgt.device,
-             dtype=dtype,
-         )
-         bottom_left = torch.zeros(
-             (num_tgt, num_memory),
-             dtype=dtype,
-             device=tgt_mask.device,
-         )
-
-         if past_key_values_length > 0:
-             tgt_mask = torch.zeros(
-                 (tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length),
-                 dtype=dtype,
-                 device=tgt_mask.device,
-             )
-
-         left = torch.cat((top_left, bottom_left), dim=0)
-         right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)
-
-         full_attention_mask = torch.cat((left, right), dim=1)[None, :]
-
-         if memory_key_padding_mask is None:
-             memory_key_padding_mask = torch.full((memory.shape[0], memory.shape[1]), fill_value=False, device=device)
-         # False means valid, i.e. not a padding position
-         if memory_key_padding_mask.dtype != torch.bool:
-             raise ValueError("Memory key padding mask must be a boolean tensor.")
-         zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
-         zero_negative_infinity[memory_key_padding_mask] = float("-inf")
-         full_attention_mask = full_attention_mask.expand(
-             (memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + past_key_values_length + num_tgt)
-         )
-         full_attention_mask = full_attention_mask.clone()
-         origin_left = full_attention_mask[:, :, :num_memory]
-         update = zero_negative_infinity[:, None, :]
-         full_attention_mask[:, :, :num_memory] = origin_left + update
-
-         # add axis for multi-head
-         full_attention_mask = full_attention_mask[:, None, :, :]
-
-         return full_attention_mask
-
      @auto_docstring
      def forward(
          self,
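For reference, the deleted helper hand-built a 4D additive mask with a block layout: image ("memory") tokens visible to every position, text ("tgt") tokens causal, and image rows blocked from attending to text. A toy reconstruction of that layout for 2 image and 3 text tokens, leaving out the cache and padding branches:

    import torch

    num_memory, num_tgt = 2, 3
    top_left = torch.zeros(num_memory, num_memory)                # image -> image: visible
    top_right = torch.full((num_memory, num_tgt), float("-inf"))  # image -> text: blocked
    bottom_left = torch.zeros(num_tgt, num_memory)                # text -> image: visible
    tgt_mask = torch.triu(torch.full((num_tgt, num_tgt), float("-inf")), diagonal=1)  # causal text

    full = torch.cat(
        (torch.cat((top_left, top_right), dim=1), torch.cat((bottom_left, tgt_mask), dim=1)), dim=0
    )
    print(full)
    # tensor([[0., 0., -inf, -inf, -inf],
    #         [0., 0., -inf, -inf, -inf],
    #         [0., 0., 0., -inf, -inf],
    #         [0., 0., 0., 0., -inf],
    #         [0., 0., 0., 0., 0.]])

The `token_type_ids` path introduced earlier in this diff encodes the same structure declaratively, so every attention backend can consume it.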
@@ -972,6 +1013,8 @@
          output_hidden_states: Optional[bool] = None,
          interpolate_pos_encoding: bool = False,
          return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         **kwargs,
      ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
          r"""
          Examples:
@@ -1003,15 +1046,6 @@

          if input_ids is not None and inputs_embeds is not None:
              raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-         elif input_ids is not None:
-             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
-             input_shape = input_ids.size()
-         elif inputs_embeds is not None:
-             input_shape = inputs_embeds.size()[:-1]
-         else:
-             raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-         seq_length = input_shape[1]

          # past_key_values_length
          past_key_values_length = 0
@@ -1022,7 +1056,23 @@
              else past_key_values.get_seq_length()
          )

-         projected_visual_features = None
+         embedding_output = self.embeddings(
+             input_ids=input_ids,
+             position_ids=position_ids,
+             inputs_embeds=inputs_embeds,
+             past_key_values_length=past_key_values_length,
+         )
+
+         if cache_position is None:
+             cache_position = torch.arange(
+                 past_key_values_length,
+                 past_key_values_length + embedding_output.shape[1],
+                 device=embedding_output.device,
+             )
+
+         # Always create `token_type_ids` so we can re-use Gemma3 style mask preparation fn
+         token_type_ids = torch.zeros_like(embedding_output, dtype=torch.int)[..., 0]
+
          if pixel_values is not None:
              if pixel_values.ndim == 4:
                  # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
@@ -1048,60 +1098,54 @@

              projected_visual_features = self.visual_projection(visual_features)

-         embedding_output = self.embeddings(
-             input_ids=input_ids,
-             position_ids=position_ids,
-             inputs_embeds=inputs_embeds,
-             past_key_values_length=past_key_values_length,
-         )
-
-         if projected_visual_features is None:
-             projected_visual_features = torch.zeros(
-                 (embedding_output.shape[0], 0, embedding_output.shape[2]),
-                 dtype=embedding_output.dtype,
-                 device=embedding_output.device,
+             # Repeat visual features to match embedding batch size.
+             projected_visual_features = projected_visual_features.repeat(
+                 embedding_output.size(0) // projected_visual_features.size(0), 1, 1
              )

-         # Repeat visual features to match embedding batch size.
-         projected_visual_features = projected_visual_features.repeat(
-             embedding_output.size(0) // projected_visual_features.size(0), 1, 1
-         )
-
-         # concatenate patch token and text token embeddings
-         hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
-
-         # By default, an additive causal mask is created
-         # for masking the future (one direction).
-         tgt_mask = self._generate_future_mask(seq_length, embedding_output.dtype, embedding_output.device)
+             # concatenate patch token and text token embeddings
+             embedding_output = torch.cat((projected_visual_features, embedding_output), dim=1)
+             image_token_type_ids = torch.ones_like(projected_visual_features, dtype=torch.int)[..., 0]
+             token_type_ids = torch.cat([image_token_type_ids, token_type_ids], dim=-1)
+             cache_position = torch.arange(embedding_output.shape[1], device=embedding_output.device, dtype=torch.int)
+             if attention_mask is not None:
+                 attention_mask = torch.cat([torch.ones_like(image_token_type_ids), attention_mask], dim=-1)
+         elif past_key_values is not None and input_ids.shape[1] == 1:
+             # Expand attention mask and cache position with image tokens because GIT doesn't add image
+             # placeholder tokens when processing. Not worth the refactor given the low usage!
+             cache_position = torch.tensor(
+                 [past_key_values_length], dtype=cache_position.dtype, device=cache_position.device
+             )
+             extended_attention_mask = torch.ones(
+                 (attention_mask.shape[0], past_key_values_length - attention_mask.shape[1] + 1),
+                 dtype=attention_mask.dtype,
+                 device=attention_mask.device,
+             )
+             attention_mask = torch.cat([extended_attention_mask, attention_mask], dim=-1)

-         # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
-         combined_attention_mask = self.create_attention_mask(
-             tgt=embedding_output,
-             memory=projected_visual_features,
-             tgt_mask=tgt_mask,
-             past_key_values_length=past_key_values_length,
+         # Images attend to each other bidirectionally while text remains causal
+         causal_mask = create_causal_mask_mapping(
+             self.config,
+             embedding_output,
+             attention_mask,
+             cache_position,
+             past_key_values,
+             None,
+             token_type_ids,
+             pixel_values,
          )

-         if attention_mask is not None:
-             # if the user provides an attention mask, we add it to the default one
-             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-             expanded_attn_mask = _prepare_4d_attention_mask(
-                 attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]
-             ).to(embedding_output.device)
-             if past_key_values_length > 0:
-                 expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
-             else:
-                 combined_attention_mask[:, :, -input_shape[1] :, -input_shape[1] :] += expanded_attn_mask
+         hidden_states = embedding_output

          encoder_outputs = self.encoder(
              hidden_states,
-             attention_mask=combined_attention_mask,
+             attention_mask=causal_mask,
              past_key_values=past_key_values,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
-             pixel_values_present=pixel_values is not None,
+             cache_position=cache_position,
          )
          sequence_output = encoder_outputs[0]
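A minimal sketch (illustrative sizes, not part of the diff) of the prefill bookkeeping above: image embeddings are prepended, and `token_type_ids` plus `attention_mask` are extended to match before the mask mapping is built:

    import torch

    batch, num_image, num_text, hidden = 1, 4, 3, 8
    projected_visual_features = torch.randn(batch, num_image, hidden)
    embedding_output = torch.randn(batch, num_text, hidden)
    attention_mask = torch.ones(batch, num_text, dtype=torch.long)

    token_type_ids = torch.zeros_like(embedding_output, dtype=torch.int)[..., 0]  # text tokens -> 0
    embedding_output = torch.cat((projected_visual_features, embedding_output), dim=1)
    image_token_type_ids = torch.ones_like(projected_visual_features, dtype=torch.int)[..., 0]
    token_type_ids = torch.cat([image_token_type_ids, token_type_ids], dim=-1)
    attention_mask = torch.cat([torch.ones_like(image_token_type_ids), attention_mask], dim=-1)

    print(token_type_ids)        # tensor([[1, 1, 1, 1, 0, 0, 0]], dtype=torch.int32)
    print(attention_mask.shape)  # torch.Size([1, 7])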
@@ -1155,6 +1199,7 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
          interpolate_pos_encoding: bool = False,
          return_dict: Optional[bool] = None,
          logits_to_keep: Union[int, torch.Tensor] = 0,
+         cache_position: Optional[torch.Tensor] = None,
          **kwargs,
      ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]:
          r"""
@@ -1304,6 +1349,7 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
              output_hidden_states=output_hidden_states,
              interpolate_pos_encoding=interpolate_pos_encoding,
              return_dict=return_dict,
+             cache_position=cache_position,
          )

          hidden_states = outputs[0]
@@ -1337,7 +1383,15 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
          )

      def prepare_inputs_for_generation(
-         self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+         self,
+         input_ids,
+         past_key_values=None,
+         pixel_values=None,
+         attention_mask=None,
+         use_cache=None,
+         cache_position=None,
+         is_first_iteration=False,
+         **kwargs,
      ):
          # Overwritten -- `git` has special cache handling and doesn't support generating from `inputs_embeds` atm

@@ -1362,11 +1416,14 @@ class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
          model_inputs = {
              "input_ids": input_ids,
              "attention_mask": attention_mask,
-             "pixel_values": kwargs.get("pixel_values"),
              "past_key_values": past_key_values,
              "use_cache": use_cache,
+             "cache_position": cache_position,
          }

+         if is_first_iteration or not use_cache:
+             model_inputs["pixel_values"] = pixel_values
+
          # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
          for key, value in kwargs.items():
              if key not in model_inputs:
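The practical effect of `is_first_iteration` above: `pixel_values` now reach the model only on the prefill step or when caching is disabled, so later decode steps skip the vision tower. A tiny sketch of the gate (the helper name is hypothetical):

    def forwards_pixel_values(is_first_iteration: bool, use_cache: bool) -> bool:
        # mirrors: if is_first_iteration or not use_cache: model_inputs["pixel_values"] = pixel_values
        return is_first_iteration or not use_cache

    print(forwards_pixel_values(True, True))    # True:  prefill runs the vision encoder
    print(forwards_pixel_values(False, True))   # False: decode reuses cached image keys/values
    print(forwards_pixel_values(False, False))  # True:  without a cache everything is recomputed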
@@ -28,7 +28,7 @@ import torch.nn as nn
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub
+ from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
  from ...masking_utils import create_causal_mask
  from ...modeling_layers import (
      GenericForSequenceClassification,
@@ -40,7 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
- from ...utils.generic import check_model_inputs
+ from ...utils.generic import check_model_inputs, maybe_autocast
  from .configuration_glm import GlmConfig

@@ -79,7 +79,7 @@ class GlmRotaryEmbedding(nn.Module):
          inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

          self.register_buffer("inv_freq", inv_freq, persistent=False)
-         self.original_inv_freq = inv_freq
+         self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

      @staticmethod
      def compute_default_rope_parameters(
@@ -120,7 +120,7 @@ class GlmRotaryEmbedding(nn.Module):
          position_ids_expanded = position_ids[:, None, :].float()

          device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+         with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
              freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
              emb = torch.cat((freqs, freqs), dim=-1)
              cos = emb.cos() * self.attention_scaling
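`maybe_autocast` is the new transformers wrapper used here in place of `torch.autocast` (presumably tolerant of device types where autocast is unavailable); the enclosed math is unchanged and still runs in float32. A standalone sketch of the rotary frequencies computed inside that context, with illustrative sizes:

    import torch

    head_dim, seq_len = 64, 8
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim/2,)
    position_ids = torch.arange(seq_len)[None, :]                                    # (1, seq_len)

    inv_freq_expanded = inv_freq[None, :, None].float().expand(1, -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()

    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # (1, seq_len, head_dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)                              # (1, seq_len, head_dim)
    cos, sin = emb.cos(), emb.sin()  # attention_scaling omitted here (defaults to 1.0)
    print(cos.shape)  # torch.Size([1, 8, 64])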
@@ -216,6 +216,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
      return q_embed, k_embed


+ @use_kernelized_func(apply_rotary_pos_emb)
  class GlmAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -239,7 +240,6 @@ class GlmAttention(nn.Module):
              config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
          )
          self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-         self.rotary_fn = apply_rotary_pos_emb

      def forward(
          self,
@@ -28,7 +28,7 @@ import torch.nn as nn
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub
+ from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
  from ...masking_utils import create_causal_mask
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
  from ...modeling_layers import (
@@ -41,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
- from ...utils.generic import check_model_inputs
+ from ...utils.generic import check_model_inputs, maybe_autocast
  from .configuration_glm4 import Glm4Config


@@ -198,6 +198,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
      return q_embed, k_embed


+ @use_kernelized_func(apply_rotary_pos_emb)
  class Glm4Attention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -221,7 +222,6 @@ class Glm4Attention(nn.Module):
              config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
          )
          self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-         self.rotary_fn = apply_rotary_pos_emb

      def forward(
          self,
@@ -284,7 +284,7 @@ class Glm4RotaryEmbedding(nn.Module):
          inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

          self.register_buffer("inv_freq", inv_freq, persistent=False)
-         self.original_inv_freq = inv_freq
+         self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

      @staticmethod
      def compute_default_rope_parameters(
@@ -325,7 +325,7 @@ class Glm4RotaryEmbedding(nn.Module):
          position_ids_expanded = position_ids[:, None, :].float()

          device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+         with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
              freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
              emb = torch.cat((freqs, freqs), dim=-1)
              cos = emb.cos() * self.attention_scaling
@@ -354,7 +354,6 @@ class Glm46VImageProcessor(BaseImageProcessor):
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
-               The max pixels of the image to resize the image.
            patch_size (`int`, *optional*, defaults to `self.patch_size`):
                The spatial patch size of the vision encoder.
            temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
@@ -381,12 +380,9 @@
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        """
-       # Try to use config values if set, otherwise fallback to global defaults
        size = size if size is not None else self.size
        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
-       elif size is None:
-           size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000}

        do_resize = do_resize if do_resize is not None else self.do_resize
        resample = resample if resample is not None else self.resample
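With the silent fallback removed, a partial `size` override now fails fast instead of being quietly patched. A small sketch of the surviving check (the helper name is hypothetical; the dict values shown are the old defaults that used to be filled in):

    def validate_size(size):
        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
        return size

    validate_size({"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000})  # ok
    validate_size({"shortest_edge": 112 * 112})  # now raises ValueError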
@@ -639,6 +639,7 @@ class Glm46VForConditionalGeneration(Glm46VPreTrainedModel, GenerationMixin):
          pixel_values_videos=None,
          image_grid_thw=None,
          video_grid_thw=None,
+         is_first_iteration=False,
          **kwargs,
      ):
          # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -655,13 +656,14 @@
              image_grid_thw=image_grid_thw,
              video_grid_thw=video_grid_thw,
              use_cache=use_cache,
+             is_first_iteration=is_first_iteration,
              **kwargs,
          )

          # GLM-4.1V position_ids are prepared with rope_deltas in forward
          model_inputs["position_ids"] = None

-         if cache_position[0] != 0:
+         if not is_first_iteration and use_cache:
              model_inputs["pixel_values"] = None
              model_inputs["pixel_values_videos"] = None

@@ -110,6 +110,9 @@ class Glm46VPreTrainedModel(Glm4vPreTrainedModel):
      _can_record_outputs = None
      _no_split_modules = None

+     def _init_weights(self, module):
+         raise AttributeError("Not needed")
+

  class Glm46VModel(Glm4vModel):
      _no_split_modules = None