transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/modeling_gguf_pytorch_utils.py
@@ -63,6 +63,24 @@ class TensorProcessor:
     def __init__(self, config=None):
         self.config = config or {}
 
+    def preprocess_name(self, hf_name: str) -> str:
+        """
+        Preprocesses the tensor name to ease loading the GGUF tensors.
+        """
+        return hf_name
+
+    def perform_fallback_tensor_mapping(
+        self, gguf_to_hf_name_map: dict[str, str], suffix: str, qual_name: str, hf_name: str
+    ):
+        """
+        Called when get_gguf_hf_weights_map fails to map a HF parameter
+        (tensor) to a corresponding GGUF one.
+
+        This is particularly useful for resolving the one-to-many
+        HF-GGUF mappings that sometimes appear in MoE models.
+        """
+        pass
+
     def process(self, weights, name, **kwargs):
         return GGUFTensor(weights, name, {})
 
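The two new hooks give architecture-specific processors control over name normalization and over mappings the generic reverse lookup misses. A minimal sketch of a subclass using both, assuming the `TensorProcessor` base class above (the subclass, its pattern, and the hard-coded layer index are hypothetical):

    import re

    class MyMoETensorProcessor(TensorProcessor):
        # Collapse per-expert indices so all experts share one packed HF name.
        EXPERT_PATTERN = re.compile(r"mlp\.experts\.\d+\.")

        def preprocess_name(self, hf_name: str) -> str:
            # "model.layers.0.mlp.experts.3.up_proj" -> "model.layers.0.mlp.experts.up_proj"
            return self.EXPERT_PATTERN.sub("mlp.experts.", hf_name)

        def perform_fallback_tensor_mapping(self, gguf_to_hf_name_map, suffix, qual_name, hf_name):
            # One-to-many mapping: two GGUF tensors feed the same HF parameter.
            if hf_name.endswith("mlp.experts.gate_up_proj"):
                gguf_to_hf_name_map[f"blk.0.ffn_gate_exps{suffix}"] = qual_name + hf_name
                gguf_to_hf_name_map[f"blk.0.ffn_up_exps{suffix}"] = qual_name + hf_name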
@@ -98,15 +116,31 @@ class LlamaTensorProcessor(TensorProcessor):
 
 
 class Qwen2MoeTensorProcessor(TensorProcessor):
+    HF_EXPERT_RENAME_PATTERN = re.compile(r"mlp.experts.\d+.")
+    HF_MOE_W13_PATTERN = re.compile(r"model\.layers\.(?P<bid>\d+)\.mlp\.experts\.gate_up_proj")
+    GGUF_MOE_WEIGHTS_PATTERN = re.compile(r"(?P<name>.*\.ffn_(?P<w>gate|down|up)_exps)\.weight$")
+
     def __init__(self, config=None):
         super().__init__(config=config)
 
-    def process(self, weights, name, **kwargs):
-        if "_exp" in name:
+    def preprocess_name(self, hf_name: str) -> str:
+        return re.sub(self.HF_EXPERT_RENAME_PATTERN, "mlp.experts.", hf_name)
+
+    def perform_fallback_tensor_mapping(
+        self, gguf_to_hf_name_map: dict[str, str], suffix: str, qual_name: str, hf_name: str
+    ):
+        # Map merged MoE weights (w1 (gate) and w3 (up)) separately.
+        if m := re.fullmatch(self.HF_MOE_W13_PATTERN, hf_name):
+            full_hf_name = qual_name + hf_name
+            gguf_to_hf_name_map[f"blk.{m['bid']}.ffn_gate_exps{suffix}"] = full_hf_name
+            gguf_to_hf_name_map[f"blk.{m['bid']}.ffn_up_exps{suffix}"] = full_hf_name
+
+    def process(self, weights, name: str, **kwargs):
+        if m := re.fullmatch(self.GGUF_MOE_WEIGHTS_PATTERN, name):
             tensor_key_mapping = kwargs.get("tensor_key_mapping")
             parsed_parameters = kwargs.get("parsed_parameters")
             if tensor_key_mapping:
-                self._split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping)
+                self._set_moe_expert_tensor(weights, parsed_parameters, tensor_key_mapping[m["name"]], m["w"])
             return GGUFTensor(weights, None, {})
         if "ffn_gate_inp_shexp" in name:
             # for compatibility tensor shared_expert_gate must be (1, 2048) dim,
@@ -114,17 +148,27 @@ class Qwen2MoeTensorProcessor(TensorProcessor):
             weights = np.expand_dims(weights, axis=0)
             return GGUFTensor(weights, name, {})
 
-    def _split_moe_expert_tensor(
-        self, weights: np.ndarray, parsed_parameters: dict[str, dict], name: str, tensor_key_mapping: dict
-    ):
-        # Original merge implementation
-        # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
-        name = tensor_key_mapping[name]
-        w_counter = self.config.get("num_experts", 60)
-        for i in range(0, w_counter):
-            temp_name = name.replace("mlp.experts.", f"mlp.experts.{i}.")
-            exp_weight = weights[i]
-            parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))
+    def _set_moe_expert_tensor(self, weights: np.ndarray, parsed_parameters: dict[str, dict], hf_name: str, w: str):
+        torch_weights = torch.from_numpy(np.copy(weights))
+        if w == "down":
+            parsed_parameters["tensors"][hf_name] = torch_weights
+        else:
+            # Double the size of the second dimension to interleave w1 (gate) and w3 (up)
+            # weights per expert (which is the first dimension).
+            # w1 (gate) comes first and w3 (up) comes second.
+            # ref: https://github.com/vllm-project/vllm/blob/8f8fda261a620234fdeea338f44093d5d8072879/vllm/model_executor/layers/fused_moe/layer.py#L988-L1015
+            shape = list(weights.shape)
+            shard_dim = 1
+            shard_size = shape[shard_dim]
+            shape[shard_dim] = shard_size * 2
+            if hf_name not in parsed_parameters["tensors"]:
+                parsed_parameters["tensors"][hf_name] = torch.zeros(shape, dtype=torch_weights.dtype)
+            out: torch.Tensor = parsed_parameters["tensors"][hf_name]
+            if w == "gate":
+                out = out.narrow(shard_dim, 0, shard_size)
+            else:  # w == "up"
+                out = out.narrow(shard_dim, shard_size, shard_size)
+            out.copy_(torch_weights)
 
 
 class BloomTensorProcessor(TensorProcessor):
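The replacement `_set_moe_expert_tensor` writes GGUF's separate gate/up expert tensors into one merged `gate_up_proj` buffer instead of splitting per expert. A self-contained sketch of the same narrow/copy_ placement on toy shapes (shapes invented for illustration):

    import torch

    num_experts, intermediate, hidden = 2, 3, 4
    gate = torch.randn(num_experts, intermediate, hidden)  # GGUF blk.N.ffn_gate_exps (w1)
    up = torch.randn(num_experts, intermediate, hidden)    # GGUF blk.N.ffn_up_exps (w3)

    # Merged buffer: dim 1 doubled; gate fills the first half, up the second.
    merged = torch.zeros(num_experts, intermediate * 2, hidden)
    merged.narrow(1, 0, intermediate).copy_(gate)
    merged.narrow(1, intermediate, intermediate).copy_(up)

    assert torch.equal(merged[:, :intermediate], gate)
    assert torch.equal(merged[:, intermediate:], up)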
@@ -281,6 +325,7 @@ def read_field(reader, field):
 # modified from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/loader.py#L1115-L1147
 def get_gguf_hf_weights_map(
     hf_model,
+    processor: TensorProcessor,
     model_type: Optional[str] = None,
     num_layers: Optional[int] = None,
     qual_name: str = "",
@@ -334,9 +379,7 @@ def get_gguf_hf_weights_map(
     gguf_to_hf_name_map = {}
     state_dict = hf_model.state_dict()
     for hf_name in state_dict:
-        # An exception for qwen2moe/qwen3moe model, where the expert layers are packed
-        if model_type in ("qwen2moe", "qwen3moe") and "mlp.experts." in hf_name:
-            hf_name = re.sub(r"mlp.experts.\d+.", "mlp.experts.", hf_name)
+        hf_name = processor.preprocess_name(hf_name)
 
         name, suffix = hf_name, ""
         if hf_name.endswith(".weight") or hf_name.endswith(".bias"):
@@ -345,6 +388,7 @@
 
         gguf_name = name_map.get_name(name)
         if gguf_name is None:
+            processor.perform_fallback_tensor_mapping(gguf_to_hf_name_map, suffix, qual_name, hf_name)
             continue
 
         gguf_to_hf_name_map[gguf_name + suffix] = qual_name + hf_name
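Taken together with the Qwen2MoeTensorProcessor patterns above, the two hook calls in this loop resolve packed expert weights roughly like this (a hedged walk-through with a representative name; the `suffix` handling is omitted):

    import re

    # preprocess_name: per-expert HF names collapse to one packed name.
    hf_name = "model.layers.7.mlp.experts.12.gate_up_proj"
    packed = re.sub(r"mlp.experts.\d+.", "mlp.experts.", hf_name)
    assert packed == "model.layers.7.mlp.experts.gate_up_proj"

    # The generic lookup knows no GGUF name for the merged gate_up_proj, so the
    # fallback registers two GGUF tensors that resolve to that one HF parameter.
    m = re.fullmatch(r"model\.layers\.(?P<bid>\d+)\.mlp\.experts\.gate_up_proj", packed)
    gguf_to_hf_name_map = {
        f"blk.{m['bid']}.ffn_gate_exps": packed,  # w1 (gate)
        f"blk.{m['bid']}.ffn_up_exps": packed,    # w3 (up)
    }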
@@ -353,7 +397,9 @@
     # Therefore, we need to check submodule as well to get a correct mapping
     if named_children := hf_model.named_children():
         for name, child in named_children:
-            sub_map = get_gguf_hf_weights_map(child, model_type, num_layers, qual_name=f"{qual_name}{name}.")
+            sub_map = get_gguf_hf_weights_map(
+                child, processor, model_type, num_layers, qual_name=f"{qual_name}{name}."
+            )
             # Ignore the keys that are already in the main map to avoid overwriting
             sub_map = {k: v for k, v in sub_map.items() if k not in gguf_to_hf_name_map}
             gguf_to_hf_name_map.update(sub_map)
@@ -507,12 +553,13 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_lo
     if return_tensors:
         parsed_parameters["tensors"] = {}
 
-    tensor_key_mapping = get_gguf_hf_weights_map(model_to_load)
     config = parsed_parameters.get("config", {})
 
     ProcessorClass = TENSOR_PROCESSORS.get(architecture, TensorProcessor)
     processor = ProcessorClass(config=config)
 
+    tensor_key_mapping = get_gguf_hf_weights_map(model_to_load, processor)
+
     for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."):
         name = tensor.name
         weights = dequantize(tensor.data, tensor.tensor_type)
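Moving the `get_gguf_hf_weights_map` call below the processor construction is what lets the map builder delegate to the new hooks. Downstream, each GGUF tensor is then routed through the processor; a rough sketch of the contract implied by the hunks above (simplified; the real loop does more bookkeeping):

    for tensor in reader.tensors:
        weights = dequantize(tensor.data, tensor.tensor_type)
        result = processor.process(
            weights,
            tensor.name,
            tensor_key_mapping=tensor_key_mapping,
            parsed_parameters=parsed_parameters,
        )
        # A processor returning GGUFTensor(weights, None, {}) has already written
        # the (possibly merged) tensor into parsed_parameters["tensors"], as
        # Qwen2MoeTensorProcessor._set_moe_expert_tensor does above.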
transformers/modeling_rope_utils.py
@@ -46,17 +46,19 @@ def dynamic_rope_update(rope_forward):
     def longrope_frequency_update(self, position_ids, device, layer_type=None):
         """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
         seq_len = torch.max(position_ids) + 1
-        original_max_position_embeddings = getattr(
-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
-        )
+
         if layer_type is None:
             rope_type = self.rope_type
             original_inv_freq = self.original_inv_freq
             prefix = ""
+            original_max_position_embeddings = self.config.rope_parameters["original_max_position_embeddings"]
         else:
             rope_type = self.rope_type[layer_type]
             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
             prefix = f"{layer_type}_"
+            original_max_position_embeddings = self.config.rope_parameters[layer_type][
+                "original_max_position_embeddings"
+            ]
 
         if seq_len > original_max_position_embeddings:
             if not hasattr(self, f"{layer_type}_long_inv_freq"):
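The update reads `original_max_position_embeddings` from the standardized `rope_parameters` dict instead of a top-level config attribute. Roughly, the two layouts it expects look like this (keys and values illustrative):

    # Single-RoPE models: one flat parameters dict.
    config.rope_parameters = {
        "rope_type": "longrope",
        "original_max_position_embeddings": 4096,
        # plus "factor", "short_factor", "long_factor", ... depending on rope_type
    }

    # Models with layer-type-specific RoPE: one dict per layer type.
    config.rope_parameters = {
        "full_attention": {"rope_type": "longrope", "original_max_position_embeddings": 4096},
        "sliding_attention": {"rope_type": "default", "original_max_position_embeddings": 4096},
    }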
@@ -223,7 +225,6 @@ def _compute_dynamic_ntk_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
     """
-    # TODO (joao): use the new `original_max_position_embeddings` from rope_parameters
     # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
     config.standardize_rope_params()
     rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
@@ -232,30 +233,29 @@ def _compute_dynamic_ntk_parameters(
     partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0)
     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
     dim = int(head_dim * partial_rotary_factor)
-    max_position_embeddings = config.max_position_embeddings
     factor = rope_parameters_dict["factor"]
     attention_factor = 1.0  # Unused in this type of RoPE
 
     # seq_len: default to max_position_embeddings, e.g. at init time
     if seq_len is None:
-        seq_len = max_position_embeddings
+        seq_len = config.max_position_embeddings
     elif isinstance(seq_len, torch.Tensor):
         seq_len = torch.maximum(
             seq_len,
-            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
+            torch.tensor(config.max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
         )
     else:
-        seq_len = max(seq_len, max_position_embeddings)
+        seq_len = max(seq_len, config.max_position_embeddings)
 
     # Compute the inverse frequencies
-    base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
+    base = base * ((factor * seq_len / config.max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
     return inv_freq, attention_factor
 
 
 def _compute_yarn_parameters(
     config: "PreTrainedConfig",
-    device: "torch.device",
+    device: Optional["torch.device"] = None,
     seq_len: Optional[int] = None,
     layer_type: Optional[str] = None,
 ) -> tuple["torch.Tensor", float]:
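As a sanity check on the dynamic NTK rescaling, a small worked example of the `base` formula above (numbers invented, not taken from any released config):

    # base' = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
    base, factor, dim = 10000.0, 2.0, 128
    max_pos, seq_len = 4096, 8192

    scaled_base = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
    print(scaled_base)  # ~30527: a larger base lowers the frequencies, stretching the context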
@@ -292,8 +292,7 @@ def _compute_yarn_parameters(
         `mscale_all_dim` are provided, `mscale_all_dim` acts scalar augmenting `log(factor)` when computing
         the denominator for the inferred value of `attention_factor`. If not provided, `attention_factor`
         will be calculated based on `factor` only.
-        * `original_max_position_embeddings` (`int`, *optional*): The original max position embeddings used
-          during pretraining. If not provided, the function falls back to `max_position_embeddings`.
+        * `original_max_position_embeddings` (`int`): The original max position embeddings used during pretraining.
         * `truncate` (`bool`, *optional*): Whether to truncate the correction range.
 
     Additionally, this function will make use of the following properties if they are found in the config:
@@ -324,15 +323,13 @@ def _compute_yarn_parameters(
     attention_factor = rope_parameters_dict.get("attention_factor")
     mscale = rope_parameters_dict.get("mscale")
     mscale_all_dim = rope_parameters_dict.get("mscale_all_dim")
+    original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]
 
-    # NOTE: DeepSeek-V3 (and potentially other models) modify `max_position_embeddings` and have an
-    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
-    # values to compute the default attention scaling factor, instead of using `factor`.
-    if "original_max_position_embeddings" in rope_parameters_dict:
-        original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]
+    # NOTE: DeepSeek-V3 (and potentially other models) have an `original_max_position_embeddings` field
+    # containing the pretrained value. They use the ratio between `max_position_embeddings` and this value
+    # to compute the default attention scaling factor, instead of using `factor`.
+    if factor is None:
         factor = config.max_position_embeddings / original_max_position_embeddings
-    else:
-        original_max_position_embeddings = config.max_position_embeddings
 
     def get_mscale(scale, mscale=1):
         if scale <= 1:
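With `original_max_position_embeddings` now required in `rope_parameters`, `factor` becomes derivable when the config omits it; the default is simply the context-extension ratio (values illustrative, in the spirit of the DeepSeek-V3 note above):

    max_position_embeddings = 163840          # extended context the model ships with
    original_max_position_embeddings = 4096   # pretraining context, from rope_parameters
    factor = max_position_embeddings / original_max_position_embeddings  # 40.0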
@@ -393,7 +390,7 @@ def _compute_yarn_parameters(
 
 def _compute_longrope_parameters(
     config: "PreTrainedConfig",
-    device: "torch.device",
+    device: Optional["torch.device"] = None,
     seq_len: Optional[int] = None,
     layer_type: Optional[str] = None,
 ) -> tuple["torch.Tensor", float]:
@@ -440,7 +437,6 @@ def _compute_longrope_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin.
     """
-    # TODO (joao): use the new `original_max_position_embeddings` from rope_parameters
     # For backward compatibility, standardize the `rope_parameters_dict` if it uses the old format
     config.standardize_rope_params()
     rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
@@ -454,14 +450,13 @@ def _compute_longrope_parameters(
     short_factor = rope_parameters_dict["short_factor"]
     factor = rope_parameters_dict.get("factor")
     attention_factor = rope_parameters_dict.get("attention_factor")
+    original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]
 
     # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
     # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
     # values to compute the default attention scaling factor, instead of using `factor`.
-    if original_max_position_embeddings := getattr(config, "original_max_position_embeddings", None):
+    if factor is None:
         factor = config.max_position_embeddings / original_max_position_embeddings
-    else:
-        original_max_position_embeddings = config.max_position_embeddings
 
     # Sets the attention factor as suggested in the paper
     if attention_factor is None:
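The longrope path mirrors the YaRN change: the pretrained length now comes from the rope parameters, and `factor` is inferred from the length ratio only when unset. A sketch of the flow, including the paper-suggested attention factor mentioned in the comment above (the derivation branch is our assumption about the surrounding code, hedged accordingly):

```python
import math

def longrope_scaling(factor, attention_factor, max_position_embeddings,
                     original_max_position_embeddings):
    # Infer `factor` from the post- / pre-extension length ratio only when unset.
    if factor is None:
        factor = max_position_embeddings / original_max_position_embeddings
    # Derive the attention factor suggested by the LongRoPE paper if none is given.
    if attention_factor is None:
        if factor <= 1.0:
            attention_factor = 1.0
        else:
            attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
    return factor, attention_factor

# Phi3-style numbers: 4k pretraining length extended to 128k
print(longrope_scaling(None, None, 131072, 4096))  # (32.0, ~1.19)
```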
@@ -483,7 +478,7 @@ def _compute_longrope_parameters(
 
 def _compute_llama3_parameters(
     config: "PreTrainedConfig",
-    device: "torch.device",
+    device: Optional["torch.device"] = None,
     seq_len: Optional[int] = None,
     layer_type: Optional[str] = None,
 ) -> tuple["torch.Tensor", float]:
@@ -587,7 +582,7 @@ class RopeParameters(TypedDict, total=False):
         most scaling types, a `factor` of x will enable the model to handle sequences of length x *
         original maximum pre-trained length.
     original_max_position_embeddings (`int`, *optional*):
-        Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+        Used with 'yarn', 'longrope' and 'llama3'. The original max position embeddings used during
         pretraining.
     attention_factor (`float`, *optional*):
         Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
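Concretely, after this change a YaRN configuration carries the pretraining length inside the rope parameters themselves. An illustrative dict (the values are made up; the key set matches the required/optional keys validated later in this file):

```python
# Hypothetical YaRN `rope_parameters` under the new schema.
rope_parameters = {
    "rope_type": "yarn",
    "rope_theta": 10000.0,
    "factor": 40.0,                            # post-yarn / pre-yarn length ratio
    "original_max_position_embeddings": 4096,  # now required for 'yarn'
    "beta_fast": 32,                           # optional
    "beta_slow": 1,                            # optional
    "truncate": False,                         # optional (newly recognized)
}
```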
@@ -641,6 +636,7 @@ class RotaryEmbeddingConfigMixin:
 
         # Standardize and validate the correctness of rotary position embeddings parameters
         self.rope_parameters.setdefault("rope_theta", kwargs.pop("rope_theta", self.default_theta))
+
         if "partial_rotary_factor" in kwargs:
             self.rope_parameters.setdefault("partial_rotary_factor", kwargs["partial_rotary_factor"])
         ignore_keys_at_rope_validation = {"partial_rotary_factor"}
@@ -671,14 +667,30 @@ class RotaryEmbeddingConfigMixin:
             rope_parameters.setdefault("rope_theta", rope_theta)
             if partial_rotary_factor is not None:
                 rope_parameters["partial_rotary_factor"] = partial_rotary_factor
+
+            # Move the pretraining-time maximum length into the rope parameters dict for RoPE types with scaling
+            if rope_parameters["rope_type"] in ["llama3", "yarn", "longrope"]:
+                if hasattr(self, "original_max_position_embeddings"):
+                    # NOTE: Phi3 (and potentially other models) save an `original_max_position_embeddings` field
+                    # containing the pretrained value outside the rope parameters. This is an exception case where
+                    # we give priority to `self.original_max_position_embeddings`
+                    self.rope_parameters["original_max_position_embeddings"] = self.original_max_position_embeddings
+                else:
+                    self.rope_parameters.setdefault("original_max_position_embeddings", self.max_position_embeddings)
+
         # Case 2: different RoPE for each layer -> several params as nested dict
         else:
-            for layer_type in layer_types:
+            for layer_type in set(layer_types):
                 rope_parameters[layer_type].setdefault("rope_type", rope_parameters[layer_type].get("type", "default"))
                 rope_parameters[layer_type].setdefault("rope_theta", rope_theta)
                 if partial_rotary_factor is not None:
                     rope_parameters[layer_type]["partial_rotary_factor"] = partial_rotary_factor
 
+                if rope_parameters[layer_type]["rope_type"] in ["llama3", "yarn", "longrope"]:
+                    self.rope_parameters[layer_type].setdefault(
+                        "original_max_position_embeddings", self.max_position_embeddings
+                    )
+
         self.rope_parameters = rope_parameters
 
     def validate_rope(self: "PreTrainedConfig", ignore_keys: Optional[set] = None):
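The standardization above reduces to a simple priority rule: for scaled RoPE types, a top-level `original_max_position_embeddings` attribute (the Phi3 convention) wins; otherwise the field defaults to `max_position_embeddings`. A minimal sketch of the single-RoPE case, assuming a bare stand-in object rather than a real `PreTrainedConfig`:

```python
def fill_original_max_position_embeddings(config):
    # Sketch of the standardization step for the single-RoPE case.
    rope_parameters = config.rope_parameters
    if rope_parameters["rope_type"] in ["llama3", "yarn", "longrope"]:
        if hasattr(config, "original_max_position_embeddings"):
            # Phi3-style configs keep the pretrained length outside the dict; it takes priority.
            rope_parameters["original_max_position_embeddings"] = config.original_max_position_embeddings
        else:
            # Otherwise assume no extension happened and fall back to the current max.
            rope_parameters.setdefault("original_max_position_embeddings", config.max_position_embeddings)

class _Cfg:  # hypothetical stand-in for a PreTrainedConfig
    max_position_embeddings = 131072
    rope_parameters = {"rope_type": "longrope"}

cfg = _Cfg()
fill_original_max_position_embeddings(cfg)
print(cfg.rope_parameters["original_max_position_embeddings"])  # 131072
```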
@@ -725,26 +737,24 @@ class RotaryEmbeddingConfigMixin:
             logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")
 
     def _validate_dynamic_rope_parameters(self, rope_parameters: dict, ignore_keys: Optional[set] = None):
-        # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
-        optional_keys = {"original_max_position_embeddings"}
         required_keys = {"rope_type", "factor"}
         received_keys = set(rope_parameters.keys())
         rope_type = rope_parameters["rope_type"]
-        self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
+        self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
 
         factor = rope_parameters["factor"]
         if factor is None or not isinstance(factor, float) or factor < 1.0:
             logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")
 
     def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys: Optional[set] = None):
-        required_keys = {"rope_type", "factor", "rope_theta"}
+        required_keys = {"rope_type", "factor", "rope_theta", "original_max_position_embeddings"}
         optional_keys = {
             "attention_factor",
             "beta_fast",
             "beta_slow",
-            "original_max_position_embeddings",
             "mscale",
             "mscale_all_dim",
+            "truncate",
         }
         received_keys = set(rope_parameters.keys())
         rope_type = rope_parameters["rope_type"]
@@ -772,37 +782,24 @@ class RotaryEmbeddingConfigMixin:
                 f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
             )
 
-        # Models should set `config.rope_parameters["original_max_position_embeddings"]` to their original (pre-yarn) context
-        # length, with `config.max_position_embeddings` corresponding to their post-yarn context length.
-        # However, for BC purposes, we allow the former to be unset.
-        original_max_position_embeddings = self.rope_parameters.get("original_max_position_embeddings")
-        if original_max_position_embeddings is not None:
-            # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
-            implicit_factor = self.max_position_embeddings / original_max_position_embeddings
-            if implicit_factor != factor:
-                logger.warning_once(
-                    f"The explicitly set RoPE scaling factor (config.rope_parameters['factor'] = {factor}) does not match "
-                    "the ratio implicitly set by other parameters (implicit factor = "
-                    "post-yarn context length / pre-yarn context length = "
-                    "config.max_position_embeddings / config.rope_parameters['original_max_position_embeddings'] = "
-                    f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
-                    "behaviour in model usage, please correct the 'max_position_embeddings' fields in the model config."
-                )
-        # No `config.rope_parameters["original_max_position_embeddings"]`. Is `config.max_position_embeddings` the
-        # pre-yarn or the post-yarn context length?
-        # BC: we assume it is the pre-yarn context length.
-        else:
+        # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
+        # NOTE: we might get `implicit_factor == 1` if the config's `original_max_position_embeddings` was
+        # inferred from `max_position_embeddings` during standardization.
+        original_max_position_embeddings = self.rope_parameters["original_max_position_embeddings"]
+        implicit_factor = self.max_position_embeddings / original_max_position_embeddings
+        if implicit_factor != factor and implicit_factor != 1:
             logger.warning_once(
-                "config.rope_parameters['original_max_position_embeddings'], the pre-yarn context length, is unset. We will "
-                "**assume** config.max_position_embeddings holds the pre-yarn context length. Some use cases may expect "
-                "config.max_position_embeddings to hold the post-yarn context length (pre-yarn context length * "
-                "factor) -- we recommend updating both fields for optimal downstream model usage."
+                f"The explicitly set RoPE scaling factor (config.rope_parameters['factor'] = {factor}) does not match "
+                "the ratio implicitly set by other parameters (implicit factor = "
+                "post-yarn context length / pre-yarn context length = "
+                "config.max_position_embeddings / config.rope_parameters['original_max_position_embeddings'] = "
+                f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
+                "behaviour in model usage, please correct the 'original_max_position_embeddings' field in the model config."
             )
 
     def _validate_longrope_rope_parameters(self, rope_parameters: dict, ignore_keys: Optional[set] = None):
-        required_keys = {"rope_type", "short_factor", "long_factor", "rope_theta"}
-        # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
-        optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
+        required_keys = {"rope_type", "short_factor", "long_factor", "rope_theta", "original_max_position_embeddings"}
+        optional_keys = {"attention_factor", "factor"}
         received_keys = set(rope_parameters.keys())
         rope_type = rope_parameters["rope_type"]
         self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
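Since `original_max_position_embeddings` is now guaranteed to be present, the YaRN consistency check reduces to plain arithmetic, with `implicit_factor == 1` tolerated because that is what standardization produces when the field was merely inferred. A sketch of the predicate with illustrative numbers (the DeepSeek-V3-like values are examples, not pulled from a real config):

```python
def yarn_factor_mismatch(factor, max_position_embeddings, original_max_position_embeddings):
    # True when the explicit factor disagrees with the length ratio
    # (and the ratio is not the degenerate 1 produced by standardization).
    implicit_factor = max_position_embeddings / original_max_position_embeddings
    return implicit_factor != factor and implicit_factor != 1

print(yarn_factor_mismatch(40.0, 163840, 4096))  # False: 163840 / 4096 == 40
print(yarn_factor_mismatch(40.0, 4096, 4096))    # False: inferred field, ratio == 1
print(yarn_factor_mismatch(40.0, 8192, 4096))    # True: ratio 2 != factor 40 -> warning
```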
@@ -827,29 +824,28 @@ class RotaryEmbeddingConfigMixin:
                 f"`rope_parameters`'s long_factor field must have length {dim // 2}, got {len(long_factor)}"
             )
 
-        # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
-        # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_parameters` and is
-        # unique to longrope (= undesirable)
-        if hasattr(self, "original_max_position_embeddings"):
+        factor = rope_parameters.get("factor")
+        original_max_position_embeddings = rope_parameters["original_max_position_embeddings"]
+
+        # Handle Phi3 divergence: we prefer the use of `attention_factor` and/or `factor` over
+        # `original_max_position_embeddings` to compute internal variables. The latter is undesirable.
+        if factor is None and original_max_position_embeddings is not None:
             logger.warning_once(
-                "This model has set a `original_max_position_embeddings` field, to be used together with "
+                "This model config has set a `rope_parameters['original_max_position_embeddings']` field, to be used together with "
                 "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_parameters` "
                 "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
                 "as it is compatible with most model architectures."
             )
-        else:
-            factor = rope_parameters.get("factor")
-            if factor is None:
-                logger.warning("Missing required keys in `rope_parameters`: 'factor'")
-            elif not isinstance(factor, float) or factor < 1.0:
-                logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")
-
-            attention_factor = rope_parameters.get("attention_factor")
-            if attention_factor is not None:
-                if not isinstance(attention_factor, float) or attention_factor < 0.0:
-                    logger.warning(
-                        f"`rope_parameters`'s attention_factor field must be a float greater than 0, got {attention_factor}"
-                    )
+        elif factor is None and original_max_position_embeddings is None:
+            logger.warning("Missing required keys in `rope_parameters`: 'factor'")
+        elif not isinstance(factor, float) or factor < 1.0:
+            logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")
+
+        attention_factor = rope_parameters.get("attention_factor")
+        if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0.0):
+            logger.warning(
+                f"`rope_parameters`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+            )
 
     def _validate_llama3_rope_parameters(self, rope_parameters: dict, ignore_keys: Optional[set] = None):
         required_keys = {
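The restructured longrope validation is now a flat if/elif chain over `factor`: warn when only the pretrained length is usable (the ratio will be inferred downstream), complain when neither value is usable, and type-check an explicit factor. A condensed, self-contained approximation of the branch logic (the warning texts here are abbreviated, not the library's exact messages):

```python
import logging

logger = logging.getLogger(__name__)

def check_longrope(factor, original_max_position_embeddings):
    # Condensed version of the if/elif chain above.
    if factor is None and original_max_position_embeddings is not None:
        logger.warning("`factor` unset; the length ratio will be used instead.")
    elif factor is None and original_max_position_embeddings is None:
        logger.warning("Missing required keys in `rope_parameters`: 'factor'")
    elif not isinstance(factor, float) or factor < 1.0:
        logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")

check_longrope(None, 4096)   # warns: ratio will be used instead
check_longrope(32.0, 4096)   # silent: a valid explicit factor wins
```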
@@ -906,6 +902,10 @@ class RotaryEmbeddingConfigMixin:
             received_keys -= {"type"}
             required_keys.add("rope_type")
 
+        optional_keys = optional_keys or set()
+        if "partial_rotary_factor" not in optional_keys:
+            optional_keys.add("partial_rotary_factor")
+
         # Some models need to store model-specific keys, and we don't want to throw a warning at them
         if ignore_keys is not None:
             received_keys -= ignore_keys
@@ -914,10 +914,7 @@ class RotaryEmbeddingConfigMixin:
         if missing_keys:
             raise KeyError(f"Missing required keys in `rope_parameters` for 'rope_type'='{rope_type}': {missing_keys}")
 
-        if optional_keys is not None:
-            unused_keys = received_keys - required_keys - optional_keys
-        else:
-            unused_keys = received_keys - required_keys
+        unused_keys = received_keys - required_keys - optional_keys
         if unused_keys:
             logger.warning(f"Unrecognized keys in `rope_parameters` for 'rope_type'='{rope_type}': {unused_keys}")