transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/glmasr/processing_glmasr.py
@@ -0,0 +1,332 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/glmasr/modular_glmasr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_glmasr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Optional, Union
+
+import numpy as np
+
+from ...audio_utils import AudioInput, make_list_of_audio
+from ...feature_extraction_utils import BatchFeature
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import TextInput
+from ...utils import is_torch_available, logging
+
+
+if is_torch_available():
+    import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+class GlmAsrProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "padding": True,
+        },
+        "audio_kwargs": {
+            "sampling_rate": 16000,
+            "chunk_length": 30.0,
+            "return_attention_mask": True,
+            "padding": "max_length",
+        },
+        "common_kwargs": {
+            "return_tensors": "pt",
+            "padding_side": "left",
+        },
+    }
+
+
+class GlmAsrProcessor(ProcessorMixin):
+ r"""
61
+ Constructs an GlmAsr processor which wraps an GlmAsr feature extractor and an GlmAsr
62
+ tokenizer into a single processor.
63
+
64
+ [`GlmAsrProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`] and
65
+ [`Qwen2TokenizerFast`]. See the [`~GlmAsrProcessor.__call__`] for more information.
66
+
67
+ Args:
68
+ feature_extractor ([`WhisperFeatureExtractor`]):
69
+ The feature extractor is a required input.
70
+ tokenizer ([`Qwen2TokenizerFast`]):
71
+ The tokenizer is a required input.
72
+ chat_template (`Optional[str]`, *optional*):
73
+ The Jinja template to use for formatting the conversation. If not provided, the tokenizer's default chat
74
+ template will be used.
75
+ audio_token (`Optional[str]`, *optional*, defaults to `"<|pad|>`"):
76
+ Special token used to represent audio inputs in the chat template.
77
+ default_transcription_prompt (`str`, *optional*, defaults to `"Please transcribe this audio into text"`):
78
+ Default prompt to use for transcription tasks when applying transcription requests.
79
+ max_audio_len (`int`, *optional*, defaults to 655):
80
+ Maximum length of audio sequences in seconds. Audio longer than this will be truncated.
81
+ 655 gives approximately 8192 tokens, corresponding to the maximum sequence length of the text model.
82
+ """
83
+
+    def __init__(
+        self,
+        feature_extractor,
+        tokenizer,
+        chat_template=None,
+        audio_token="<|pad|>",
+        default_transcription_prompt="Please transcribe this audio into text",
+        max_audio_len=655,
+    ):
+        self.audio_token = audio_token
+        self.audio_token_id = tokenizer.convert_tokens_to_ids(audio_token)
+        self.default_transcription_prompt = default_transcription_prompt
+        self.max_audio_len = max_audio_len
+        super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
+
+    def _get_audio_token_length(self, audio_lengths: "torch.Tensor") -> "torch.Tensor":
+        merge_factor = 4
+        for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]:
+            audio_lengths = (audio_lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
+
+        num_tokens = (audio_lengths - merge_factor) // merge_factor + 1
+        return num_tokens
+
+    def __call__(
+        self,
+        text: Union[TextInput, list[TextInput]],
+        audio: Optional[AudioInput] = None,
+        output_labels: Optional[bool] = False,
+        **kwargs: Unpack[GlmAsrProcessorKwargs],
+    ) -> BatchFeature:
+ r"""
115
+ Main method to prepare one or several text sequence(s) and audio waveform(s) for the model. This
116
+ method expands `<sound>` placeholders in the text based on the post-pool frame counts of the
117
+ audio windows, then tokenizes the provided strings as-is, and extracts log-mel features
118
+ with [`WhisperFeatureExtractor`]. If `audio` is `None`, no audio processing is performed and
119
+ the text is tokenized as-is (LM-only behavior).
120
+
121
+ Args:
122
+ text (`str` or `list[str]`):
123
+ Input sequence or batch of sequences.
124
+ audio (`np.ndarray` or `list[np.ndarray]`):
125
+ Input audio or batch of audios as NumPy arrays. If provided, there must be as many `text` inputs as
126
+ `audio` inputs.
127
+ output_labels (bool, *optional*, default=False):
128
+ Whether to return labels for training.
129
+
130
+ Returns:
131
+ [`BatchFeature`]: A dictionary with tokenized text (`input_ids`, `attention_mask`) and
132
+ audio features (`input_features`, `input_features_mask`).
133
+ """
134
+
+        # Merge defaults with user kwargs
+        call_kwargs = self._merge_kwargs(
+            GlmAsrProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        text_kwargs = call_kwargs["text_kwargs"]
+        audio_kwargs = call_kwargs["audio_kwargs"]
+        return_tensors = text_kwargs.get("return_tensors")
+        if return_tensors != "pt":
+            raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
+
+        if isinstance(text, str):
+            text = [text]
+        elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
+            raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+        audio_inputs = {}
+        if audio is not None:
+            audio = make_list_of_audio(audio)
+            if len(text) != len(audio):
+                raise ValueError(f"Got {len(text)} texts but {len(audio)} audios; they must match 1:1.")
+
+            # Determine the number of chunks per sample, and flatten
+            window_size = int(audio_kwargs["sampling_rate"] * audio_kwargs["chunk_length"])
+            max_windows = int(self.max_audio_len // audio_kwargs["chunk_length"])
+
+            per_sample_windows: list[int] = []
+            flat_chunks: list[np.ndarray] = []
+
+            for audio_el in audio:
+                n_samples = int(audio_el.shape[0])
+                n_win = max(1, (n_samples + window_size - 1) // window_size)
+                if n_win > max_windows:
+                    logger.warning(
+                        f"Audio duration ({n_samples / audio_kwargs['sampling_rate']:.1f}s) exceeds {self.max_audio_len}s; truncating to first {self.max_audio_len}s."
+                    )
+                    n_win = max_windows
+                per_sample_windows.append(n_win)
+
+                time_cap = min(n_samples, n_win * window_size)
+                for i in range(n_win):
+                    start = i * window_size
+                    end = min((i + 1) * window_size, time_cap)
+                    flat_chunks.append(audio_el[start:end])
+
+            # Feature extraction
+            audio_inputs = self.feature_extractor(flat_chunks, **audio_kwargs)
+            padding_mask = audio_inputs.pop("attention_mask")
+            audio_inputs["input_features_mask"] = padding_mask
+
+            # Compute per-sample audio sequence lengths for token counting
+            audio_lengths = torch.stack([s.sum() for s in torch.split(padding_mask.sum(-1), per_sample_windows)])
+            audio_tokens_lengths = self._get_audio_token_length(audio_lengths)
+
+            # Expand audio tokens in text
+            for i, audio_length in enumerate(audio_tokens_lengths):
+                expanded = re.sub(re.escape(self.audio_token), self.audio_token * audio_length, text[i])
+                text[i] = expanded
+
+        # Tokenize
+        text_inputs = self.tokenizer(text, **text_kwargs)
+
+        data = {**text_inputs, **audio_inputs}
+        if output_labels:
+            labels = data["input_ids"].clone()
+            labels[labels == self.audio_token_id] = -100
+            labels[labels == self.tokenizer.pad_token_id] = -100
+            data["labels"] = labels
+
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    @property
+    def model_input_names(self) -> list[str]:
+        tok_names = self.tokenizer.model_input_names
+        fea_names = self.feature_extractor.model_input_names
+        return list(dict.fromkeys(tok_names + fea_names + ["input_features_mask"]))
+
+    def apply_transcription_request(
+        self,
+        audio: Union[str, list[str], AudioInput],
+        prompt: Optional[Union[str, list[str]]] = None,
+        **kwargs: Unpack[GlmAsrProcessorKwargs],
+    ) -> BatchFeature:
+ """
221
+ Prepare inputs for automatic speech recognition without manually writing the default transcription prompt.
222
+
223
+ Args:
224
+ audio (`str`, `list[str]`, `np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
225
+ Audio to transcribe. Strings are interpreted as local paths or URLs and will be loaded automatically by
226
+ the chat template loader; NumPy arrays and PyTorch tensors are forwarded directly.
227
+ prompt (`str` or `list[str]`, *optional*):
228
+ Custom prompt(s) to include in the user turn. A list must be the same length as the batch. When `None`,
229
+ each sample uses `"Transcribe the input speech."`.
230
+ **kwargs:
231
+ Additional keyword arguments forwarded to [`~AudioFlamingo3Processor.apply_chat_template`] (for example
232
+ `text_kwargs`, `audio_kwargs`, ...).
233
+
234
+ Returns:
235
+ [`BatchFeature`]: Processor outputs ready to be passed to [`AudioFlamingo3ForConditionalGeneration.generate`].
236
+
237
+ """
238
+
+        if isinstance(audio, str):
+            audio_items: list[Union[str, np.ndarray]] = [audio]
+        elif isinstance(audio, (list, tuple)) and audio and all(isinstance(el, str) for el in audio):
+            audio_items = list(audio)
+        else:
+            audio_items = list(make_list_of_audio(audio))
+            if is_torch_available():
+                audio_items = [el.detach().cpu().numpy() if isinstance(el, torch.Tensor) else el for el in audio_items]
+
+        batch_size = len(audio_items)
+        if batch_size == 0:
+            raise ValueError("`audio` must contain at least one sample.")
+
+        if prompt is None:
+            prompts = [self.default_transcription_prompt] * batch_size
+        elif isinstance(prompt, str):
+            prompts = [prompt] * batch_size
+        elif isinstance(prompt, (list, tuple)):
+            if len(prompt) != batch_size:
+                raise ValueError(
+                    f"Received {len(prompt)} prompt(s) for {batch_size} audio sample(s); counts must match."
+                )
+            prompts = []
+            for item in prompt:
+                if item is None:
+                    prompts.append(self.default_transcription_prompt)
+                elif isinstance(item, str):
+                    prompts.append(item)
+                else:
+                    raise TypeError("Each prompt must be a string or `None`.")
+        else:
+            raise TypeError("`prompt` must be a string, a sequence of strings, or `None`.")
+
+        conversations = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "audio", "path": audio_item}
+                        if isinstance(audio_item, str)
+                        else {"type": "audio", "audio": audio_item},
+                        {"type": "text", "text": prompt_text},
+                    ],
+                }
+            ]
+            for prompt_text, audio_item in zip(prompts, audio_items)
+        ]
+
+        return self.apply_chat_template(
+            conversations,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            **kwargs,
+        )
+
+    def batch_decode(self, *args, strip_prefix=False, **kwargs):
+        """
+        Forward arguments to [`~PreTrainedTokenizer.batch_decode`] and optionally remove the assistant framing the
+        model was trained to produce.
+
+        Transcription requests respond with sentences such as `"The spoken content of the audio is \"...\"."`.
+        Setting `strip_prefix=True` trims the fixed prefix, keeping just the transcription text.
+        """
+        decoded = self.tokenizer.batch_decode(*args, **kwargs)
+        if strip_prefix:
+            decoded = [self._strip_assistant_prefix_and_quotes(text) for text in decoded]
+        return decoded
+
308
+ def _strip_assistant_prefix_and_quotes(self, text: str) -> str:
309
+ """
310
+ Remove the assistant prefix and surrounding quotes from a decoded transcription string.
311
+ """
312
+
313
+ stripped = text.strip()
314
+
315
+ for prefix in (
316
+ "The spoken content of the audio is",
317
+ "The transcription of the audio is",
318
+ ):
319
+ if stripped.startswith(prefix):
320
+ stripped = stripped[len(prefix) :].strip()
321
+ break
322
+
323
+ if stripped.endswith("."):
324
+ stripped = stripped[:-1].strip()
325
+
326
+ if len(stripped) >= 2 and stripped[0] == stripped[-1] and stripped[0] in {"'", '"'}:
327
+ stripped = stripped[1:-1].strip()
328
+
329
+ return stripped
330
+
331
+
332
+ __all__ = ["GlmAsrProcessor"]
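Editor's note: the helper above is deterministic string surgery, so its round trip can be shown standalone (a re-implementation for illustration, not the library call):

```python
def strip_prefix_and_quotes(text: str) -> str:
    # Mirrors _strip_assistant_prefix_and_quotes above: drop the fixed framing,
    # then a trailing period, then one pair of matching surrounding quotes.
    stripped = text.strip()
    for prefix in ("The spoken content of the audio is", "The transcription of the audio is"):
        if stripped.startswith(prefix):
            stripped = stripped[len(prefix):].strip()
            break
    if stripped.endswith("."):
        stripped = stripped[:-1].strip()
    if len(stripped) >= 2 and stripped[0] == stripped[-1] and stripped[0] in {"'", '"'}:
        stripped = stripped[1:-1].strip()
    return stripped

assert strip_prefix_and_quotes('The spoken content of the audio is "hello world".') == "hello world"
```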
@@ -107,7 +107,6 @@ class GLPNImageProcessorFast(BaseImageProcessorFast):
  processed_groups[shape] = stacked_images
 
  processed_images = reorder_images(processed_groups, grouped_index)
- processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
  return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
  def post_process_depth_estimation(self, outputs, target_sizes=None):
@@ -189,7 +189,6 @@ class GotOcr2ImageProcessorFast(BaseImageProcessorFast):
  processed_images_grouped[shape] = stacked_images
 
  processed_images = reorder_images(processed_images_grouped, grouped_images_index)
- processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
  return BatchFeature(
  data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors
@@ -433,6 +433,7 @@ class GotOcr2VisionEncoder(GotOcr2PreTrainedModel):
  self.neck = GotOcr2VisionNeck(config)
 
  self.gradient_checkpointing = False
+ self.post_init()
 
  def get_input_embeddings(self):
  return self.patch_embed
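Editor's note: the added `post_init()` follows the standard `PreTrainedModel` convention: build all submodules first, then run weight initialization and final processing once. A minimal sketch of that pattern (the class and field names are illustrative, not part of this diff):

```python
import torch.nn as nn
from transformers import PreTrainedConfig, PreTrainedModel

class ToyEncoder(PreTrainedModel):  # illustrative class
    config_class = PreTrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.proj = nn.Linear(16, 16)  # build all submodules first
        # Must come last: runs _init_weights over the module tree plus any
        # final processing PreTrainedModel registers.
        self.post_init()
```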
@@ -796,6 +797,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
  attention_mask=None,
  cache_position=None,
  logits_to_keep=None,
+ is_first_iteration=False,
  **kwargs,
  ):
  # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -807,12 +809,15 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
  attention_mask=attention_mask,
  cache_position=cache_position,
  logits_to_keep=logits_to_keep,
+ is_first_iteration=is_first_iteration,
  **kwargs,
  )
 
- if cache_position[0] == 0:
- # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
- # Otherwise we need pixel values to be passed to model
+ if is_first_iteration or not kwargs.get("use_cache", True):
+ # Pixel values are used only in the first iteration, if available
+ # In subsequent iterations, they are already merged with text and cached
+ # NOTE: the first iteration doesn't have to be prefill; it can be the first
+ # iteration with a question and a cached system prompt (continuing generation from cache)
  model_inputs["pixel_values"] = pixel_values
 
  return model_inputs
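Editor's note: a standalone restatement of the gating logic just added, to make the control flow explicit (a simplified sketch, not the library function):

```python
def select_pixel_values(pixel_values, is_first_iteration: bool, use_cache: bool = True):
    # First pass (or cache disabled): vision features still need to be computed
    # and merged with the text stream, so pixel_values are forwarded.
    # Later cached steps: the merged features already live in the KV cache,
    # so forwarding pixel_values again would be redundant.
    if is_first_iteration or not use_cache:
        return pixel_values
    return None
```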
@@ -103,7 +103,6 @@ class GPT2Attention(nn.Module):
  ),
  persistent=False,
  )
- self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
 
  self.embed_dim = config.hidden_size
  self.num_heads = config.num_attention_heads
@@ -476,12 +475,8 @@ class GPT2PreTrainedModel(PreTrainedModel):
  _supports_flash_attn = True
  _supports_sdpa = True
  _supports_attention_backend = True
-
  _can_compile_fullgraph = True
 
- def __init__(self, *inputs, **kwargs):
- super().__init__(*inputs, **kwargs)
-
  @torch.no_grad()
  def _init_weights(self, module):
  """Initialize the weights."""
@@ -497,6 +492,14 @@ class GPT2PreTrainedModel(PreTrainedModel):
  elif isinstance(module, nn.LayerNorm):
  init.zeros_(module.bias)
  init.ones_(module.weight)
+ elif isinstance(module, GPT2Attention):
+ max_positions = module.config.max_position_embeddings
+ init.copy_(
+ module.bias,
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
+ 1, 1, max_positions, max_positions
+ ),
+ )
 
  # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
  # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
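Editor's note: the buffer written by `init.copy_` above is the standard lower-triangular causal mask; on a toy size it looks like this:

```python
import torch

max_positions = 4
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
    1, 1, max_positions, max_positions
)
# bias[0, 0, i, j] is True exactly when j <= i, i.e. each position may attend
# only to itself and earlier positions:
# [[ True, False, False, False],
#  [ True,  True, False, False],
#  [ True,  True,  True, False],
#  [ True,  True,  True,  True]]
```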
@@ -26,7 +26,6 @@ from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  from ...generation import GenerationMixin
  from ...masking_utils import create_causal_mask
- from ...modeling_flash_attention_utils import is_flash_attn_available
  from ...modeling_layers import GradientCheckpointingLayer
  from ...modeling_outputs import (
  BaseModelOutputWithPastAndCrossAttentions,
@@ -43,10 +42,6 @@ from ...utils import (
  from .configuration_gpt_bigcode import GPTBigCodeConfig
 
 
- if is_flash_attn_available():
- pass
-
-
  logger = logging.get_logger(__name__)
 
 
@@ -360,9 +355,6 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
  _supports_flash_attn = True
  _supports_sdpa = True
 
- def __init__(self, *inputs, **kwargs):
- super().__init__(*inputs, **kwargs)
-
  @torch.no_grad()
  def _init_weights(self, module):
  """Initialize the weights."""
@@ -377,6 +369,9 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
  init.normal_(
  module.c_proj.weight, mean=0.0, std=self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
  )
+ elif isinstance(module, GPTBigCodeModel):
+ max_positions = module.config.max_position_embeddings
+ init.copy_(module.bias, torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)))
 
 
  @auto_docstring
@@ -20,6 +20,7 @@ import torch
  from torch import nn
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+ from ... import initialization as init
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
@@ -70,11 +71,11 @@ class GPTNeoSelfAttention(nn.Module):
  # local causal self attention is a sliding window where each token can only attend to the previous
  # window_size tokens. This is implemented by updating the causal mask such that for each token
  # all other tokens are masked except the previous window_size tokens.
+ self.attention_type = attention_type
  if attention_type == "local":
  bias = torch.bitwise_xor(bias, torch.tril(bias, -config.window_size))
 
  self.register_buffer("bias", bias, persistent=False)
- self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
 
  self.attn_dropout = nn.Dropout(float(config.attention_dropout))
  self.resid_dropout = nn.Dropout(float(config.resid_dropout))
@@ -237,8 +238,8 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
  else torch.get_autocast_gpu_dtype()
  )
  # Handle the case where the model is quantized
- elif hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
+ elif hasattr(self.config, "quantization_config"):
+ target_dtype = self.config.dtype
  else:
  target_dtype = self.q_proj.weight.dtype
 
@@ -382,6 +383,17 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
  _supports_flash_attn = True
  _can_compile_fullgraph = False  # TODO: needs a hybrid cache
 
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, GPTNeoSelfAttention):
+ max_positions = module.config.max_position_embeddings
+ bias = torch.tril(torch.ones((max_positions, max_positions), dtype=bool)).view(
+ 1, 1, max_positions, max_positions
+ )
+ if module.attention_type == "local":
+ bias = torch.bitwise_xor(bias, torch.tril(bias, -module.config.window_size))
+ init.copy_(module.bias, bias)
+
 
  @auto_docstring
  class GPTNeoModel(GPTNeoPreTrainedModel):
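Editor's note: the `bitwise_xor` trick above carves a sliding window out of the full causal mask; on a toy size:

```python
import torch

n, window = 5, 2
causal = torch.tril(torch.ones((n, n), dtype=torch.bool))
# tril(causal, -window) keeps entries with j <= i - window; XOR removes them
# from the causal mask, so local[i, j] is True only for i - window < j <= i:
# a band of `window` positions ending at the token itself.
local = torch.bitwise_xor(causal, torch.tril(causal, -window))
print(local.int())
```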
@@ -66,7 +66,7 @@ class GPTNeoXRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
  @staticmethod
  def compute_default_rope_parameters(
@@ -78,7 +78,7 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
  @staticmethod
  def compute_default_rope_parameters(
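Editor's note: the same one-line change recurs in the GPT-OSS rotary embedding further below. Its effect can be shown in isolation: a non-persistent buffer moves with `module.to(device)` and no longer aliases `inv_freq`'s storage, unlike the old plain attribute:

```python
import torch
from torch import nn

m = nn.Module()
inv_freq = torch.arange(4.0)
m.register_buffer("inv_freq", inv_freq, persistent=False)

# Old pattern: plain attribute, aliasing the same tensor and ignored by .to()/.cuda()
m.original_inv_freq_attr = inv_freq

# New pattern: independent storage, moved with the module, still excluded from the
# state dict (persistent=False)
m.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

m.inv_freq.mul_(2)
print(m.original_inv_freq_attr)  # mutated alongside inv_freq (shared storage)
print(m.original_inv_freq)       # unchanged independent clone
```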
@@ -117,5 +117,22 @@ class GptOssConfig(PreTrainedConfig):
  **kwargs,
  )
 
+ def __setattr__(self, key, value):
+ """
+ Overridden to validate that a supported attention implementation is requested.
+
+ Because `set_attn_implementation` internally assigns `_attn_implementation_internal = "..."`, overriding only
+ the specific attention setter is not enough. Using a property/setter for `_attn_implementation_internal` would
+ create a recursive dependency (`_attn_implementation` acts as a wrapper around `_attn_implementation_internal`) -
+ hence, this workaround.
+ """
+ if key in ("_attn_implementation", "_attn_implementation_internal"):
+ if value and "flash" in value and value.removeprefix("paged|") != "kernels-community/vllm-flash-attn3":
+ raise ValueError(
+ f"GPT-OSS model does not support the specified flash attention implementation: {value}. "
+ "Only `kernels-community/vllm-flash-attn3` is supported."
+ )
+ super().__setattr__(key, value)
+
 
  __all__ = ["GptOssConfig"]
@@ -28,8 +28,7 @@ from torch.nn import functional as F
  from ... import initialization as init
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
- from ...integrations import use_kernelized_func
- from ...integrations.hub_kernels import use_kernel_forward_from_hub
+ from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  from ...modeling_layers import (
  GenericForSequenceClassification,
@@ -89,8 +88,8 @@ class GptOssExperts(nn.Module):
 
  Args:
  hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
- selected_experts (torch.Tensor): (batch_size * token_num, top_k)
- routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+ selected_experts (torch.Tensor): (batch_size * seq_len, top_k)
+ routing_weights (torch.Tensor): (batch_size * seq_len, top_k)
  Returns:
  torch.Tensor
  """
@@ -160,8 +159,8 @@ class GptOssTopKRouter(nn.Module):
 
  def forward(self, hidden_states):
  hidden_states = hidden_states.reshape(-1, self.hidden_dim)
- router_logits = F.linear(hidden_states, self.weight, self.bias)  # (seq_len, num_experts)
- router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
+ router_logits = F.linear(hidden_states, self.weight, self.bias)  # (num_tokens, num_experts)
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (num_tokens, top_k)
  router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
  router_scores = router_top_value
  return router_logits, router_scores, router_indices
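Editor's note: the corrected shape comments describe a flat token axis. A toy version of the same routing math makes the shapes concrete:

```python
import torch
import torch.nn.functional as F

num_tokens, num_experts, top_k = 4, 8, 2
hidden = torch.randn(num_tokens, 16)
weight = torch.randn(num_experts, 16)
bias = torch.zeros(num_experts)

router_logits = F.linear(hidden, weight, bias)                        # (num_tokens, num_experts)
top_vals, router_indices = torch.topk(router_logits, top_k, dim=-1)   # (num_tokens, top_k)
router_scores = F.softmax(top_vals, dim=1)                            # normalized over the k chosen experts
```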
@@ -197,7 +196,7 @@ class GptOssRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
  @staticmethod
  def compute_default_rope_parameters(
@@ -445,8 +444,6 @@ class GptOssPreTrainedModel(PreTrainedModel):
  "attentions": GptOssAttention,
  }
  _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
- _supports_flash_attention = False
- _supports_flex_attention = False
 
  @torch.no_grad()
  def _init_weights(self, module):
@@ -21,7 +21,7 @@ from torch.nn import functional as F
 
  from ... import initialization as init
  from ...cache_utils import Cache, DynamicCache
- from ...integrations.hub_kernels import use_kernel_forward_from_hub
+ from ...integrations import use_kernel_forward_from_hub
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  from ...modeling_outputs import (
  MoeModelOutputWithPast,
@@ -86,8 +86,8 @@ class GptOssExperts(nn.Module):
 
  Args:
  hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
- selected_experts (torch.Tensor): (batch_size * token_num, top_k)
- routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+ selected_experts (torch.Tensor): (batch_size * seq_len, top_k)
+ routing_weights (torch.Tensor): (batch_size * seq_len, top_k)
  Returns:
  torch.Tensor
  """
@@ -157,8 +157,8 @@ class GptOssTopKRouter(nn.Module):
 
  def forward(self, hidden_states):
  hidden_states = hidden_states.reshape(-1, self.hidden_dim)
- router_logits = F.linear(hidden_states, self.weight, self.bias)  # (seq_len, num_experts)
- router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
+ router_logits = F.linear(hidden_states, self.weight, self.bias)  # (num_tokens, num_experts)
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (num_tokens, top_k)
  router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
  router_scores = router_top_value
  return router_logits, router_scores, router_indices
@@ -354,8 +354,6 @@ class GptOssDecoderLayer(LlamaDecoderLayer):
  class GptOssPreTrainedModel(LlamaPreTrainedModel):
  _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
  _supports_sdpa = False
- _supports_flash_attention = False
- _supports_flex_attention = False
  _can_record_outputs = {
  "router_logits": OutputRecorder(GptOssTopKRouter, index=0),
  "hidden_states": GptOssDecoderLayer,