transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/gptj/modeling_gptj.py

@@ -14,12 +14,14 @@
 # limitations under the License.
 """PyTorch GPT-J model."""
 
+import math
 from typing import Optional, Union
 
 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
@@ -77,7 +79,7 @@ class GPTJAttention(nn.Module):
     def __init__(self, config, layer_idx=None):
         super().__init__()
         self.config = config
-        max_positions = config.max_position_embeddings
+        self.max_positions = config.max_position_embeddings
 
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
@@ -99,15 +101,17 @@ class GPTJAttention(nn.Module):
                 f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                 f" `num_attention_heads`: {self.num_attention_heads})."
             )
-        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
+        self.scale_attn = math.sqrt(self.head_dim)
 
         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.rotary_dim = config.rotary_dim
-        pos_embd_dim = self.rotary_dim or self.embed_dim
-        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
+        self.pos_embd_dim = self.rotary_dim or self.embed_dim
+        self.register_buffer(
+            "embed_positions", create_sinusoidal_positions(self.max_positions, self.pos_embd_dim), persistent=False
+        )
 
     def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
         """
@@ -334,8 +338,8 @@ class GPTJFlashAttention2(GPTJAttention):
                 else torch.get_autocast_gpu_dtype()
             )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
 
@@ -444,6 +448,11 @@ class GPTJPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _can_compile_fullgraph = True
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, GPTJAttention):
+            init.copy_(module.embed_positions, create_sinusoidal_positions(module.max_positions, module.pos_embd_dim))
+
 
 @auto_docstring
 class GPTJModel(GPTJPreTrainedModel):
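Note on the `embed_positions` change above: turning the sinusoidal table into a non-persistent buffer keeps it out of `state_dict()` (the tensor is fully deterministic, so there is no reason to checkpoint it) while still letting it follow `model.to(device)`; the new `_init_weights` branch then rebuilds it whenever weights are (re)initialized. A minimal sketch of the pattern in plain PyTorch; `sinusoidal_positions` and `ToyAttention` are illustrative stand-ins, not the transformers code:

```python
import torch
from torch import nn


def sinusoidal_positions(num_pos, dim):
    # analogous to create_sinusoidal_positions: deterministic, so never checkpointed
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.einsum("i,j->ij", torch.arange(num_pos).float(), inv_freq)
    return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)


class ToyAttention(nn.Module):
    def __init__(self, max_positions=16, dim=8):
        super().__init__()
        self.max_positions, self.pos_embd_dim = max_positions, dim
        # persistent=False: follows .to(device)/dtype casts, excluded from state_dict
        self.register_buffer(
            "embed_positions", sinusoidal_positions(max_positions, dim), persistent=False
        )


attn = ToyAttention()
assert "embed_positions" not in attn.state_dict()
# Absent from checkpoints, the table must be rebuilt at load/init time, which is
# the role of the new GPTJPreTrainedModel._init_weights branch (via init.copy_):
with torch.no_grad():
    attn.embed_positions.copy_(sinusoidal_positions(attn.max_positions, attn.pos_embd_dim))
```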
transformers/models/granite/modeling_granite.py

@@ -337,7 +337,7 @@ class GraniteRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
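The `original_inv_freq` one-liner above (repeated in several models in this release) swaps a plain tensor attribute for a cloned, non-persistent buffer. Plain attributes are invisible to `nn.Module`, so `.to()` does not move or cast them, and without `.clone()` the attribute aliases `inv_freq`, which dynamic RoPE scaling mutates in place. A toy demonstration (illustrative module, not the transformers class):

```python
import torch
from torch import nn


class Rope(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = torch.ones(4)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.plain_attr = inv_freq  # old style: plain, aliased attribute
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)  # new style


rope = Rope()
rope.inv_freq.mul_(2)                    # e.g. a dynamic RoPE-scaling update in place
print(rope.plain_attr[0].item())         # 2.0 -- the alias silently changed
print(rope.original_inv_freq[0].item())  # 1.0 -- the clone kept the original

rope.to(torch.float64)
print(rope.original_inv_freq.dtype)      # torch.float64: buffers follow .to()
print(rope.plain_attr.dtype)             # torch.float32: plain attributes do not
```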
transformers/models/granite_speech/modeling_granite_speech.py

@@ -293,6 +293,12 @@ class GraniteSpeechPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, GraniteSpeechEncoderProjector):
             init.normal_(module.query)
+        elif isinstance(module, GraniteSpeechCTCEncoder):
+            context_size = module.config.context_size
+            seq = torch.arange(context_size)
+            relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
+            attention_dists = torch.clamp(relpos_dist, -context_size, context_size) + module.config.max_pos_emb
+            init.copy_(module.attention_dists, attention_dists)
 
 
 @auto_docstring(
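The new `GraniteSpeechCTCEncoder` branch above rebuilds the encoder's relative-position index table at weight-initialization time instead of relying on it being checkpointed. Entry `(i, j)` is the clamped signed distance `i - j`, shifted by `max_pos_emb` so it can index an embedding table. A standalone reproduction with toy sizes (plain torch; the `context_size` and `max_pos_emb` values here are made up):

```python
import torch

context_size, max_pos_emb = 4, 4  # illustrative values only
seq = torch.arange(context_size)
relpos_dist = seq.view(-1, 1) - seq.view(1, -1)  # entry (i, j) = i - j
attention_dists = torch.clamp(relpos_dist, -context_size, context_size) + max_pos_emb
print(attention_dists)
# tensor([[4, 3, 2, 1],
#         [5, 4, 3, 2],
#         [6, 5, 4, 3],
#         [7, 6, 5, 4]])
# Signed distances become non-negative indices, with the main diagonal at max_pos_emb.
```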
@@ -322,6 +328,12 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin):
 
         self.post_init()
 
+    def set_decoder(self, decoder):
+        self.language_model.set_decoder(decoder)
+
+    def get_decoder(self):
+        return self.language_model.get_decoder()
+
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
@@ -458,6 +470,7 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin):
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward audio inputs to the model
@@ -469,13 +482,14 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin):
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
         # If we're in cached decoding stage, input_features should be None because
         # input ids do not contain special audio token anymore Otherwise we need
         # input feature values to be passed to the model
-        if cache_position[0] == 0:
+        if is_first_iteration or not kwargs.get("use_cache", True):
             model_inputs["input_features"] = input_features
         return model_inputs
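The `is_first_iteration` hunks above thread an explicit flag from the generation loop into `prepare_inputs_for_generation`, replacing the `cache_position[0] == 0` heuristic. The old check only detects the very first forward when the whole prompt is processed as a single chunk, and it says nothing about `use_cache=False`, where the features must be re-fed on every step because nothing is cached. A condensed sketch of the resulting control flow (toy function, not the transformers API):

```python
def prepare_inputs(input_features, is_first_iteration=False, use_cache=True):
    # Toy stand-in for the multimodal prepare_inputs_for_generation logic above.
    model_inputs = {}
    if is_first_iteration or not use_cache:
        model_inputs["input_features"] = input_features
    return model_inputs


print(prepare_inputs("feats", is_first_iteration=True))  # {'input_features': 'feats'}
print(prepare_inputs("feats"))                           # {} -- cached decoding steps
print(prepare_inputs("feats", use_cache=False))          # re-fed on every step
```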
transformers/models/granitemoe/modeling_granitemoe.py

@@ -80,7 +80,7 @@ class GraniteMoeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -456,8 +456,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": GraniteMoeDecoderLayer,
transformers/models/granitemoe/modular_granitemoe.py

@@ -146,8 +146,7 @@ class GraniteMoePreTrainedModel(LlamaPreTrainedModel, PreTrainedModel):
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
-
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
 
     @torch.no_grad()
     def _init_weights(self, module):
transformers/models/granitemoehybrid/configuration_granitemoehybrid.py

@@ -92,6 +92,8 @@ class GraniteMoeHybridConfig(PreTrainedConfig):
             allow the model to output the auxiliary loss.
         router_aux_loss_coef (`float`, *optional*, defaults to 0.001): router auxiliary loss coefficient
         shared_intermediate_size (`int`, *optional*, defaults to 1024): intermediate size for shared experts.
+        position_embedding_type (`str`, *optional*):
+            Positional embedding type to be used; defaults to None. Allowed options: `[None, "rope"]`
         layer_types (`List`, *optional*): list of strings to be used as layer types.
             Allowed choices: "mamba", "attention".
         mamba_n_heads (`int`, *optional*, defaults to 128):
@@ -159,6 +161,7 @@ class GraniteMoeHybridConfig(PreTrainedConfig):
         output_router_logits: Optional[bool] = False,
         router_aux_loss_coef: Optional[float] = 0.001,
         shared_intermediate_size: Optional[int] = 1024,
+        position_embedding_type: Optional[str] = None,
         layer_types: Optional[list[str]] = None,
         mamba_n_heads: Optional[int] = 128,
         mamba_n_groups: Optional[int] = 1,
@@ -198,6 +201,7 @@ class GraniteMoeHybridConfig(PreTrainedConfig):
         self.output_router_logits = output_router_logits
         self.router_aux_loss_coef = router_aux_loss_coef
         self.shared_intermediate_size = shared_intermediate_size
+        self.position_embedding_type = position_embedding_type
         self.rope_parameters = rope_parameters
 
         mamba_intermediate = mamba_expand * hidden_size
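Usage sketch for the new config field: `position_embedding_type` defaults to `None`, and `GraniteMoeHybridModel` (see the modeling diff below) only builds `self.rotary_emb` when it is `"rope"`. This assumes the rc2 wheel and default values for every other constructor argument:

```python
from transformers import GraniteMoeHybridConfig

config = GraniteMoeHybridConfig(position_embedding_type="rope")  # opt in to RoPE
assert config.position_embedding_type == "rope"

# Leaving the field at its default keeps rotary embeddings disabled:
assert GraniteMoeHybridConfig().position_embedding_type is None
```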
transformers/models/granitemoehybrid/modeling_granitemoehybrid.py

@@ -31,7 +31,12 @@ from transformers.activations import ACT2FN
 from ... import initialization as init
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+from ...integrations import (
+    lazy_load_kernel,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -40,22 +45,9 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.generic import check_model_inputs, maybe_autocast
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 from .configuration_granitemoehybrid import GraniteMoeHybridConfig
 
 
-if is_mamba_2_ssm_available():
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
-else:
-    selective_state_update = None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-
 logger = logging.get_logger(__name__)
@@ -165,6 +157,7 @@ class GraniteMoeHybridAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # None or rope embeddings
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
@@ -174,6 +167,10 @@ class GraniteMoeHybridAttention(nn.Module):
         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
         if past_key_values is not None:
             cache_kwargs = {"cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -371,9 +368,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask):
     return hidden_states
 
 
-is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
-
-
 # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
 class GraniteMoeHybridMambaLayer(nn.Module):
     """
@@ -445,6 +439,20 @@ class GraniteMoeHybridMambaLayer(nn.Module):
 
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
 
+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
+        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
@@ -915,7 +923,7 @@ class GraniteMoeHybridRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -1231,8 +1239,7 @@ class GraniteMoeHybridPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": GraniteMoeHybridDecoderLayer,
@@ -1265,7 +1272,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
             [GraniteMoeHybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config=config)
+        self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config) if config.position_embedding_type == "rope" else None
         self.gradient_checkpointing = False
         self.embedding_multiplier = config.embedding_multiplier
 
@@ -1313,7 +1320,9 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
 
         # embed positions
         hidden_states = inputs_embeds
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        position_embeddings = None
+        if self.rotary_emb is not None:
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         for decoder_layer in self.layers:
             # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
@@ -1547,6 +1556,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMixin):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
@@ -1579,7 +1589,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMixin):
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and empty_past_kv:
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
@@ -39,6 +39,7 @@ from ..granitemoeshared.modeling_granitemoeshared import (
     GraniteMoeSharedModel,
     GraniteMoeSharedMoE,
     GraniteMoeSharedPreTrainedModel,
+    apply_rotary_pos_emb,
     eager_attention_forward,
 )
 from .configuration_granitemoehybrid import GraniteMoeHybridConfig
@@ -57,6 +58,7 @@ class GraniteMoeHybridAttention(GraniteMoeSharedAttention):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # None or rope embeddings
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
@@ -66,6 +68,10 @@ class GraniteMoeHybridAttention(GraniteMoeSharedAttention):
         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
         if past_key_values is not None:
             cache_kwargs = {"cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -203,6 +209,7 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):
             [GraniteMoeHybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.embedding_multiplier = config.embedding_multiplier
+        self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config) if config.position_embedding_type == "rope" else None
 
     @auto_docstring
     @check_model_inputs
@@ -245,7 +252,9 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):
 
         # embed positions
         hidden_states = inputs_embeds
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        position_embeddings = None
+        if self.rotary_emb is not None:
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         for decoder_layer in self.layers:
             # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
@@ -300,6 +309,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
@@ -332,7 +342,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM):
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and empty_past_kv:
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
@@ -462,8 +462,7 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": GraniteMoeSharedDecoderLayer,
@@ -494,7 +493,7 @@ class GraniteMoeSharedRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -34,7 +34,7 @@ class GroundingDinoConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.
 
     Args:
-        backbone_config (`PreTrainedConfig` or `dict`, *optional*, defaults to `ResNetConfig()`):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `SwinConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -285,9 +285,8 @@ class GroundingDinoConfig(PreTrainedConfig):
         self.positional_embedding_temperature = positional_embedding_temperature
         self.init_std = init_std
         self.layer_norm_eps = layer_norm_eps
+
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True
-        self.tie_encoder_decoder = True
 
 
 __all__ = ["GroundingDinoConfig"]
@@ -1415,7 +1415,7 @@ class GroundingDinoPreTrainedModel(PreTrainedModel):
         elif isinstance(module, GroundingDinoFusionLayer):
             init.constant_(module.vision_param, 1e-4)
             init.constant_(module.text_param, 1e-4)
-        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
             init.normal_(module.weight, mean=0.0, std=std)
             if module.bias is not None:
                 init.zeros_(module.bias)
@@ -1511,7 +1511,7 @@ class GroundingDinoEncoder(GroundingDinoPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         **kwargs,
-    ):
+    ) -> Union[tuple, GroundingDinoEncoderOutput]:
         r"""
         Args:
             vision_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -1666,7 +1666,7 @@ class GroundingDinoDecoder(GroundingDinoPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         **kwargs,
-    ):
+    ) -> Union[tuple, GroundingDinoDecoderOutput]:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
@@ -2059,7 +2059,7 @@ class GroundingDinoModel(GroundingDinoPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         **kwargs,
-    ):
+    ) -> Union[tuple, GroundingDinoModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -758,14 +758,19 @@ class GroupViTPreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=init_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
-        elif isinstance(module, nn.LayerNorm):
+        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
 
         factor = self.config.initializer_factor
         if isinstance(module, GroupViTTextEmbeddings):
             init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
             init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, GroupViTAttention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
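The `getattr` guard in this hunk (and in the identical Hubert and HGNetV2 hunks below) distinguishes `nn.BatchNorm1d`, which carries `running_mean`/`running_var`/`num_batches_tracked` buffers, from `nn.LayerNorm`, which has none. A self-contained sketch of the same reset logic in plain PyTorch, with in-place buffer methods standing in for the `init` helpers:

```python
import torch
from torch import nn

def reset_norm(module: nn.Module) -> None:
    # Affine parameters exist on both LayerNorm and BatchNorm1d by default.
    nn.init.zeros_(module.bias)
    nn.init.ones_(module.weight)
    # Running statistics are buffers, not parameters, and only BatchNorm has them.
    if getattr(module, "running_mean", None) is not None:
        module.running_mean.zero_()
        module.running_var.fill_(1.0)
        module.num_batches_tracked.zero_()

reset_norm(nn.BatchNorm1d(8))  # resets weight/bias and the running statistics
reset_norm(nn.LayerNorm(8))    # guard is skipped: no running_mean attribute
```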
@@ -79,7 +79,7 @@ class HeliumRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -26,6 +26,7 @@ import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...modeling_outputs import BackboneOutput, BaseModelOutputWithNoAttention, ImageClassifierOutputWithNoAttention
 from ...modeling_utils import PreTrainedModel
@@ -45,6 +46,15 @@ class HGNetV2PreTrainedModel(PreTrainedModel):
     input_modalities = ("image",)
     _no_split_modules = ["HGNetV2BasicLayer"]
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        # We need to check it like that as d_fine models replace the BatchNorm2d by their own
+        if "BatchNorm" in module.__class__.__name__:
+            init.ones_(module.weight)
+            init.zeros_(module.bias)
+            init.zeros_(module.running_mean)
+            init.ones_(module.running_var)
+
 
 class HGNetV2LearnableAffineBlock(nn.Module):
     def __init__(self, scale_value: float = 1.0, bias_value: float = 0.0):
@@ -20,6 +20,7 @@ import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 
+from ... import initialization as init
 from ...configuration_utils import PreTrainedConfig
 from ...modeling_outputs import (
     BackboneOutput,
@@ -170,6 +171,15 @@ class HGNetV2PreTrainedModel(PreTrainedModel):
     input_modalities = ("image",)
     _no_split_modules = ["HGNetV2BasicLayer"]
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        # We need to check it like that as d_fine models replace the BatchNorm2d by their own
+        if "BatchNorm" in module.__class__.__name__:
+            init.ones_(module.weight)
+            init.zeros_(module.bias)
+            init.zeros_(module.running_mean)
+            init.ones_(module.running_var)
+
 
 class HGNetV2LearnableAffineBlock(nn.Module):
     def __init__(self, scale_value: float = 1.0, bias_value: float = 0.0):
@@ -648,6 +648,10 @@ class HubertPreTrainedModel(PreTrainedModel):
         elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.Conv1d):
             if is_deepspeed_zero3_enabled():
                 import deepspeed
@@ -145,6 +145,10 @@ class HubertPreTrainedModel(PreTrainedModel):
         elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.Conv1d):
             if is_deepspeed_zero3_enabled():
                 import deepspeed
@@ -320,7 +320,7 @@ class HunYuanDenseV1RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -148,7 +148,7 @@ class HunYuanDenseV1RotaryEmbedding(LlamaRotaryEmbedding):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
 
 class HunYuanDenseV1Model(LlamaModel):
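The `original_inv_freq` change repeated across the rotary embedding classes above swaps a plain tensor attribute for a non-persistent buffer. A small sketch of what that buys: buffers follow the module through `.to()` and dtype casts, while `persistent=False` keeps them out of `state_dict`, so existing checkpoints stay compatible:

```python
import torch
from torch import nn

class RotarySketch(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # A plain attribute would be left behind by `.to("cuda")`; a buffer
        # moves with the module. `.clone()` decouples it from `inv_freq`,
        # which dynamic RoPE scaling mutates in place.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

m = RotarySketch()
print(list(m.state_dict().keys()))  # [] -> neither buffer is serialized
```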
@@ -6,7 +6,7 @@ from ...utils.import_utils import define_import_structure
 
 if TYPE_CHECKING:
     from .configuration_hunyuan_v1_moe import *
-    from .modeling_hunyuan import *
+    from .modeling_hunyuan_v1_moe import *
 else:
     import sys
 
@@ -30,14 +30,19 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
 from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_hunyuan_v1_moe import HunYuanMoEV1Config
 
@@ -244,6 +249,7 @@ class HunYuanMoEV1Gate(nn.Module):
         return logits
 
 
+@use_experts_implementation
 class HunYuanMoEV1Experts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
@@ -371,7 +377,9 @@ class HunYuanMoEV1PreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph = False
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": HunYuanMoEV1DecoderLayer,
@@ -413,7 +421,7 @@ class HunYuanMoEV1RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(