transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0

transformers/integrations/accelerate.py

@@ -21,7 +21,6 @@ import inspect
 import os
 import re
 from collections import OrderedDict, defaultdict
-from contextlib import contextmanager
 from typing import TYPE_CHECKING

 from safetensors import safe_open
@@ -55,114 +54,6 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-@contextmanager
-def init_empty_weights(include_buffers: bool = False):
-    """
-    A context manager under which models are initialized with all parameters on the meta device, therefore creating an
-    empty model. Useful when just initializing the model would blow the available RAM.
-
-    Args:
-        include_buffers (`bool`, *optional*):
-            Whether or not to also put all buffers on the meta device while initializing.
-
-    Example:
-
-    ```python
-    import torch.nn as nn
-    from accelerate import init_empty_weights
-
-    # Initialize a model with 100 billions parameters in no time and without using any RAM.
-    with init_empty_weights():
-        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
-    ```
-
-    <Tip warning={true}>
-
-    Any model created under this context manager has no weights. As such you can't do something like
-    `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
-    Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
-    called.
-
-    </Tip>
-    """
-    with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
-        yield f
-
-
-@contextmanager
-def init_on_device(device: "torch.device", include_buffers: bool = False):
-    """
-    A context manager under which models are initialized with all parameters on the specified device.
-
-    Args:
-        device (`torch.device`):
-            Device to initialize all parameters on.
-        include_buffers (`bool`, *optional*):
-            Whether or not to also put all buffers on the meta device while initializing.
-
-    Example:
-
-    ```python
-    import torch.nn as nn
-    from accelerate import init_on_device
-
-    with init_on_device(device=torch.device("cuda")):
-        tst = nn.Linear(100, 100)  # on `cuda` device
-    ```
-    """
-    if include_buffers:
-        with device:
-            yield
-        return
-
-    old_register_parameter = nn.Module.register_parameter
-    if include_buffers:
-        old_register_buffer = nn.Module.register_buffer
-
-    def register_empty_parameter(module, name, param):
-        old_register_parameter(module, name, param)
-        if param is not None:
-            param_cls = type(module._parameters[name])
-            kwargs = module._parameters[name].__dict__
-            kwargs["requires_grad"] = param.requires_grad
-            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
-
-    def register_empty_buffer(module, name, buffer, persistent=True):
-        old_register_buffer(module, name, buffer, persistent=persistent)
-        if buffer is not None:
-            module._buffers[name] = module._buffers[name].to(device)
-
-    # Patch tensor creation
-    if include_buffers:
-        tensor_constructors_to_patch = {
-            torch_function_name: getattr(torch, torch_function_name)
-            for torch_function_name in ["empty", "zeros", "ones", "full"]
-        }
-    else:
-        tensor_constructors_to_patch = {}
-
-    def patch_tensor_constructor(fn):
-        def wrapper(*args, **kwargs):
-            kwargs["device"] = device
-            return fn(*args, **kwargs)
-
-        return wrapper
-
-    try:
-        nn.Module.register_parameter = register_empty_parameter
-        if include_buffers:
-            nn.Module.register_buffer = register_empty_buffer
-        for torch_function_name in tensor_constructors_to_patch:
-            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
-        yield
-    finally:
-        nn.Module.register_parameter = old_register_parameter
-        if include_buffers:
-            nn.Module.register_buffer = old_register_buffer
-        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
-            setattr(torch, torch_function_name, old_torch_function)
-
-
 def check_and_set_device_map(device_map: "torch.device | int | str | dict | None") -> dict | str | None:
     from ..modeling_utils import get_torch_context_manager_or_global_device

@@ -182,6 +73,10 @@ def check_and_set_device_map(device_map: "torch.device | int | str | dict | None
         device_map = {"": device_map}
     elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
         try:
+            if device_map == "cuda":
+                # setting to the local rank
+                local_rank = int(os.environ.get("LOCAL_RANK", 0))
+                device_map = f"cuda:{local_rank}"
             device_map = {"": torch.device(device_map)}
         except RuntimeError:
             raise ValueError(
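
The four added lines above pin a bare `"cuda"` device map to the current process's local rank. A self-contained sketch of the resulting normalization, assuming the launcher (e.g. torchrun) exports `LOCAL_RANK` and falling back to 0 otherwise:

```python
import os

import torch

# Sketch of the normalization added to check_and_set_device_map above: a bare "cuda"
# string is pinned to this process's local rank before being wrapped into a
# module -> device dict (LOCAL_RANK is assumed to be set by the launcher).
device_map = "cuda"
if device_map == "cuda":
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device_map = f"cuda:{local_rank}"
device_map = {"": torch.device(device_map)}
print(device_map)  # {'': device(type='cuda', index=0)} when LOCAL_RANK is unset
```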
@@ -398,7 +293,7 @@ def _get_device_map(
     # especially if the model uses WeightConverter (because there will be some uncontrollable cpu memory spikes during
     # the conversions before we resave the weights). In those cases, it's better to offload to disk a bit more
     # if we were in-between, as otherwise we blow-up cpu memory
-    if max_memory is None:
+    if max_memory is None and "cpu" in inferred_max_memory:
         inferred_max_memory["cpu"] *= 0.90

     if hf_quantizer is not None:
@@ -458,10 +353,13 @@ def accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload
     dispatch_model(model, **device_map_kwargs)


-def expand_device_map(device_map, param_names):
+def expand_device_map(device_map: dict | None, param_names: list[str]):
     """
     Expand a device map to return the correspondence parameter name to device.
     """
+    if device_map is None:
+        return dict.fromkeys(param_names, "cpu")
+
     # Here, we first sort by number of submodules, then length of the full string, to make sure to match correctly
     device_map_regex = re.compile(
         "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
@@ -474,6 +372,15 @@ def expand_device_map(device_map, param_names):
     return new_device_map


+def get_device(device_map: dict | None, param_name: str, valid_torch_device: bool = False) -> torch.device | str | int:
+    """Return the device on which `param_name` should be according to the `device_map`. If `valid_torch_device` is `True`,
+    then if the device is `"disk"`, `"cpu"` will be returned instead."""
+    device = expand_device_map(device_map, [param_name])[param_name]
+    if valid_torch_device and device == "disk":
+        return "cpu"
+    return device
+
+
 def accelerate_disk_offload(
     model: "PreTrainedModel",
     disk_offload_folder: str | None,
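
`expand_device_map` (and the new `get_device` helper built on it) resolves each parameter name against the device-map prefixes by regex, deepest prefix first. An illustrative, self-contained sketch of that matching; the device map and parameter names below are made up for the example:

```python
import re

# Illustrative stand-ins (not the transformers API): a tiny device_map and a few
# parameter names, to show how prefix matching resolves each parameter's device.
device_map = {"model.layers.0": 0, "model.layers.1": "disk", "": "cpu"}
param_names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.1.mlp.up_proj.weight",
    "lm_head.weight",
]

# Same idea as expand_device_map: deepest / longest prefixes are tried first,
# and the empty prefix "" acts as the catch-all.
device_map_regex = re.compile(
    "|".join(rf"({k})" for k in sorted(device_map, key=lambda x: (x.count("."), len(x)), reverse=True))
)
expanded = {name: device_map[device_map_regex.match(name).group()] for name in param_names}
print(expanded)
# {'model.layers.0.self_attn.q_proj.weight': 0,
#  'model.layers.1.mlp.up_proj.weight': 'disk',
#  'lm_head.weight': 'cpu'}

# get_device(..., valid_torch_device=True) additionally maps "disk" back to "cpu",
# and a device_map of None now expands to "cpu" for every parameter.
```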
@@ -554,6 +461,32 @@ def offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str |
     return offload_index


+def load_offloaded_parameter(model: "PreTrainedModel", param_name: str) -> torch.Tensor:
+    """Load `param_name` from disk, if it was offloaded due to the device_map, and thus lives as a meta parameter
+    inside `model`.
+    This is needed when resaving a model, when some parameters were offloaded (we need to load them from disk, to
+    then resave them to disk in the correct shard...)."""
+    # Start from the most inner module, and try to find the hook that was used for offloading the param
+    module_parts = param_name.split(".")
+    modules_to_check = [".".join(module_parts[:-idx]) for idx in range(1, len(module_parts))] + [""]
+    for parent_name in modules_to_check:
+        parent = model.get_submodule(parent_name)
+        if hasattr(parent, "_hf_hook"):
+            weights_map = parent._hf_hook.weights_map
+            truncated_param_name = param_name.replace(f"{parent_name}." if parent_name != "" else parent_name, "")
+            break
+    # If we did not break the loop, something is wrong
+    else:
+        raise ValueError(
+            f"{param_name} is on the meta device because it was offloaded, but we could not find "
+            "the corresponding hook for it"
+        )
+
+    # This call loads it from disk
+    tensor = weights_map[truncated_param_name]
+    return tensor
+
+
 def _init_infer_auto_device_map(
     model: nn.Module,
     max_memory: dict[int | str, int | str] | None = None,
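
The hunks above remove the vendored `init_empty_weights` / `init_on_device` helpers; the integration diffs below replace their call sites with PyTorch's built-in device context manager. A minimal sketch of that replacement (plain PyTorch 2.x, no accelerate needed):

```python
import torch
import torch.nn as nn

# Modules built under torch.device("meta") allocate no real storage: parameters are
# created on the meta device, so large layers can be instantiated without using RAM.
with torch.device("meta"):
    layer = nn.Linear(10_000, 10_000)

print(layer.weight.device)   # meta
print(layer.weight.is_meta)  # True

# Real weights must be materialized or loaded separately, e.g. allocate empty storage
# on a concrete device and then copy a state dict into it.
layer = layer.to_empty(device="cpu")
```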

transformers/integrations/aqlm.py

@@ -14,13 +14,11 @@
 "AQLM (Additive Quantization of Language Model) integration file"

 from ..quantizers.quantizers_utils import should_convert_module
-from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils import is_torch_available, logging


-if is_accelerate_available():
-    from accelerate import init_empty_weights
-
 if is_torch_available():
+    import torch
     import torch.nn as nn

 logger = logging.get_logger(__name__)
@@ -46,7 +44,7 @@ def replace_with_aqlm_linear(model, modules_to_not_convert: list[str] | None = N
     for module_name, module in model.named_modules():
         if not should_convert_module(module_name, modules_to_not_convert):
             continue
-        with init_empty_weights():
+        with torch.device("meta"):
             if isinstance(module, nn.Linear):
                 new_module = QuantizedLinear(
                     module.in_features,

transformers/integrations/awq.py

@@ -16,12 +16,9 @@
 from typing import Optional, Union

 from ..quantizers.quantizers_utils import should_convert_module
-from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils import is_torch_available, logging


-if is_accelerate_available():
-    from accelerate import init_empty_weights
-
 if is_torch_available():
     import torch
     import torch.nn as nn
@@ -97,7 +94,7 @@ def replace_with_awq_linear(
     for module_name, module in model.named_modules():
         if not should_convert_module(module_name, modules_to_not_convert):
             continue
-        with init_empty_weights():
+        with torch.device("meta"):
             if isinstance(module, nn.Linear):
                 new_module = target_cls(
                     bits=quantization_config.bits,

transformers/integrations/bitnet.py

@@ -1,10 +1,7 @@
 from ..quantizers.quantizers_utils import should_convert_module
-from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils import is_torch_available, logging


-if is_accelerate_available():
-    from accelerate import init_empty_weights
-
 if is_torch_available():
     import torch
     import torch.nn as nn
@@ -92,7 +89,7 @@ def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:

     Explanation of the example:
     ---------------------------
-    Let's take the first value for example 0b10100001, we we will only focus on the first column,
+    Let's take the first value for example 0b10100001, we will only focus on the first column,
     because every element is unpacked across the first dimension
     - First 2 bits: `01` → 0 at [0][0]
     - Second 2 bits: `00` → -1 at [0][2]
@@ -173,7 +170,7 @@ class BitLinear(nn.Module):
        Activation function : Performs symmetric, per-token quantization on the input activations.
        Parameters:
        -----------
-        x : torch.Tensor
+        input : torch.Tensor
            Input activations to be quantized.
        num_bits : int, optional (default=8)
            Number of bits to use for quantization, determining the quantization range.
@@ -334,7 +331,7 @@ def replace_with_bitnet_linear(model, modules_to_not_convert: list[str] | None =
     for module_name, module in model.named_modules():
         if not should_convert_module(module_name, modules_to_not_convert):
             continue
-        with init_empty_weights():
+        with torch.device("meta"):
             if isinstance(module, nn.Linear):
                 if quantization_config and quantization_config.linear_class == "autobitlinear":
                     new_module = AutoBitLinear(
@@ -365,7 +362,7 @@ def replace_with_bitnet_linear(model, modules_to_not_convert: list[str] | None =

     if not has_been_replaced:
         logger.warning(
-            "You are loading your model using eetq but no linear modules were found in your model."
+            "You are loading your model using bitnet but no linear modules were found in your model."
             " Please double check your model architecture, or submit an issue on github if you think this is"
             " a bug."
         )
@@ -22,7 +22,6 @@ if is_torch_available():

  if is_accelerate_available():
  import accelerate
- from accelerate import init_empty_weights
  from accelerate.hooks import add_hook_to_module, remove_hook_from_module

  logger = logging.get_logger(__name__)
@@ -181,7 +180,7 @@ def replace_with_bnb_linear(
  if not should_convert_module(module_name, modules_to_not_convert):
  continue
  new_module = None
- with init_empty_weights():
+ with torch.device("meta"):
  if isinstance(module, (nn.Linear, Conv1D)):
  if isinstance(module, Conv1D):
  in_features, out_features = module.weight.shape
@@ -233,7 +232,7 @@ def replace_with_bnb_linear(


  # Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
- def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
+ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
  """
  Helper function to dequantize 4bit or 8bit bnb weights.

@@ -248,10 +247,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", st

  if cls_name == "Params4bit":
  output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
- logger.warning_once(
- f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
- )
- return output_tensor.to(dtype)
+ return output_tensor

  if state.SCB is None:
  state.SCB = weight.SCB
@@ -263,7 +259,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", st
  # Multiply by (scale/127) to dequantize.
  dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3

- return dequantized.to(dtype)
+ return dequantized


  def _create_accelerate_new_hook(old_hook):
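On the 8-bit path, `dequantize_bnb_weight` now returns the product of the int8 weights and `SCB / 127` without any final cast (`7.874015718698502e-3` is `1/127`). A small numeric illustration of that scaling, independent of bitsandbytes:

import torch

# Fake int8 rows plus a per-row absmax scale, mimicking how LLM.int8 stores weights.
w_int8 = torch.tensor([[127, -64, 0], [32, -127, 5]], dtype=torch.int8)
scb = torch.tensor([0.5, 2.0])  # per-output-row absmax (SCB)

# Multiply by (scale / 127) to map the int8 range back to floats.
w_fp = w_int8.float() * scb.view(-1, 1) * (1.0 / 127.0)
print(w_fp)
# row 0 lands in [-0.5, 0.5], row 1 in [-2.0, 2.0]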
@@ -283,10 +279,7 @@ def _create_accelerate_new_hook(old_hook):
  return new_hook


- def dequantize_and_replace(
- model,
- quantization_config=None,
- ):
+ def dequantize_and_replace(model, quantization_config=None, dtype=None):
  """
  Converts a quantized model into its dequantized original version. The newly converted model will have
  some performance drop compared to the original model before quantization - use it only for specific usecases
@@ -297,14 +290,22 @@ def dequantize_and_replace(
  quant_method = quantization_config.quantization_method()

  target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit
-
  for module_name, module in model.named_modules():
  if isinstance(module, target_cls):
- with init_empty_weights():
+ with torch.device("meta"):
  bias = getattr(module, "bias", None)
  new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)
  state = module.state if quant_method == "llm_int8" else None
- new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, model.dtype, state))
+ new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+ weight = dequantize_bnb_weight(module.weight, state)
+ if dtype is None:
+ logger.warning_once(
+ f"The modules are dequantized in {weight.dtype}. If you want to change the dtype, please specify `dtype` in `dequantize`. "
+ )
+ else:
+ logger.warning_once(f"The modules are dequantized in {weight.dtype} and casted to {dtype}.")
+ weight = weight.to(dtype)
+ new_module.weight = torch.nn.Parameter(weight)
  if bias is not None:
  new_module.bias = bias
  if hasattr(module, "_hf_hook"):
@@ -304,6 +304,15 @@ def _load_state_dict_into_zero3_model(model_to_load, state_dict):
  state_dict._metadata = metadata

  error_msgs = []
+ meta_model_state_dict = model_to_load.state_dict()
+ missing_keys = set(meta_model_state_dict.keys())
+
+ prefix_model = getattr(model_to_load, "base_model_prefix", None)
+ # take care of the case where in the checkpoint we don't have the prefix
+ state_dict = {
+ (f"{prefix_model}.{k}" if meta_model_state_dict.get(f"{prefix_model}.{k}") is not None else k): v
+ for k, v in state_dict.items()
+ }

  # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
  # so we need to apply the function recursively.
@@ -320,7 +329,14 @@ def _load_state_dict_into_zero3_model(model_to_load, state_dict):
  # In sharded models, each shard has only part of the full state_dict, so only gather
  # parameters that are in the current state_dict.
  named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
- params_to_gather = [named_parameters[k] for k in named_parameters if k in state_dict]
+ params_to_gather = []
+ for k in named_parameters:
+ if k in state_dict:
+ param = named_parameters[k]
+ # crucial to not init the weight again
+ param._is_hf_initialized = True
+ params_to_gather.append(param)
+ missing_keys.discard(k)

  if len(params_to_gather) > 0:
  # because zero3 puts placeholders in model params, this context
@@ -333,11 +349,10 @@ def _load_state_dict_into_zero3_model(model_to_load, state_dict):
  for name, child in module._modules.items():
  if child is not None:
  load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
- child._is_hf_initialized = True

  load(model_to_load, state_dict, assign_to_params_buffers=False)

- return error_msgs
+ return error_msgs, missing_keys


  def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
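The added block above re-keys the checkpoint when it was saved without the model's `base_model_prefix` and keeps a `missing_keys` set of parameters the checkpoint never provided. A standalone sketch of that re-keying, assuming a prefix of "model" (the helper name is hypothetical):

def remap_with_prefix(state_dict: dict, model_keys: set, prefix: str):
    """Prepend the base-model prefix to checkpoint keys only when the prefixed name
    exists in the model, and report model keys the checkpoint never covered."""
    remapped = {
        (f"{prefix}.{k}" if f"{prefix}.{k}" in model_keys else k): v
        for k, v in state_dict.items()
    }
    missing = model_keys - set(remapped)
    return remapped, missing

model_keys = {"model.embed.weight", "model.layer.weight", "lm_head.weight"}
checkpoint = {"embed.weight": 0, "layer.weight": 1}  # saved without the "model." prefix
remapped, missing = remap_with_prefix(checkpoint, model_keys, "model")
print(sorted(remapped))  # ['model.embed.weight', 'model.layer.weight']
print(sorted(missing))   # ['lm_head.weight']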
@@ -14,15 +14,13 @@
  # limitations under the License.
  from ..core_model_loading import ConversionOps
  from ..quantizers.quantizers_utils import should_convert_module
- from ..utils import is_accelerate_available, is_torch_available, logging
+ from ..utils import is_torch_available, logging


  if is_torch_available():
  import torch
  import torch.nn as nn

- if is_accelerate_available():
- from accelerate import init_empty_weights

  logger = logging.get_logger(__name__)

@@ -97,7 +95,7 @@ def replace_with_eetq_linear(model, modules_to_not_convert: list[str] | None = N
  Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
  for numerical stability reasons.
  """
- from kernels import get_kernel
+ from .hub_kernels import get_kernel

  global eetq_kernels_hub
  eetq_kernels_hub = get_kernel("kernels-community/quantization-eetq")
@@ -108,7 +106,7 @@ def replace_with_eetq_linear(model, modules_to_not_convert: list[str] | None = N
  for module_name, module in model.named_modules():
  if not should_convert_module(module_name, modules_to_not_convert):
  continue
- with init_empty_weights():
+ with torch.device("meta"):
  if isinstance(module, nn.Linear):
  new_module = EetqLinear(
  module.in_features, module.out_features, bias=module.bias is not None, **module_kwargs
@@ -257,7 +257,7 @@ class FbgemmFp8Llama4TextExperts(nn.Module):
  @lru_cache(maxsize=1)
  def get_quantize_fp8_per_row():
  if _is_torch_xpu_available:
- from kernels import get_kernel
+ from .hub_kernels import get_kernel

  return get_kernel("kernels-community/fp8-fbgemm").quantize_fp8_per_row
  return torch.ops.fbgemm.quantize_fp8_per_row
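`get_quantize_fp8_per_row` memoizes which FP8 row-quantization op to use: a Hub kernel on XPU, the native fbgemm op otherwise. A generic sketch of that cached-dispatch pattern (names are illustrative only):

from functools import lru_cache

@lru_cache(maxsize=1)
def get_quantize_fn(on_xpu: bool):
    """Resolve the quantization callable once; later calls reuse the cached result."""
    if on_xpu:
        # The real code loads "kernels-community/fp8-fbgemm" from the Hub here.
        def quantize(x):
            return ("hub kernel", x)
    else:
        # Elsewhere the native torch/fbgemm op is used.
        def quantize(x):
            return ("torch op", x)
    return quantize

fn = get_quantize_fn(False)
print(fn(3.0)[0])                    # "torch op"
print(get_quantize_fn(False) is fn)  # True: served from the cache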
@@ -15,7 +15,7 @@

  from ..core_model_loading import ConversionOps
  from ..quantizers.quantizers_utils import should_convert_module
- from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
+ from ..utils import is_torch_accelerator_available, is_torch_available, logging


  if is_torch_available():
@@ -25,23 +25,16 @@ if is_torch_available():
  import triton.language as tl
  from torch.nn import functional as F

- if is_accelerate_available():
- from accelerate import init_empty_weights
-

  logger = logging.get_logger(__name__)
  try:
  _FP8_DTYPE = torch.float8_e4m3fn
  _FP8_MIN = torch.finfo(_FP8_DTYPE).min
  _FP8_MAX = torch.finfo(_FP8_DTYPE).max
- _FP8_IS_INT = False
  except AttributeError:
- _FP8_DTYPE = torch.int8
- _FP8_MIN, _FP8_MAX = -127, 127
- _FP8_IS_INT = True
- logger.warning_once(
- "torch.float8_e4m3fn not available; falling back to int8 emulation for Fp8Quantize operations."
- )
+ _FP8_DTYPE = None
+ _FP8_MIN, _FP8_MAX = -448, 448
+ logger.warning_once("torch.float8_e4m3fn not available")


  # Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
@@ -618,7 +611,7 @@ def replace_with_fp8_linear(
  # we need this to correctly materialize the weights during quantization
  module_kwargs = {} if pre_quantized else {"dtype": None}
  new_module = None
- with init_empty_weights():
+ with torch.device("meta"):
  if module_name.endswith(".experts"):
  new_module = FP8Expert(
  config=model.config, block_size=quantization_config.weight_block_size, **module_kwargs
@@ -701,10 +694,7 @@ class Fp8Quantize(ConversionOps):
  scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3) # -> (..., rows_tiles, 1, cols_tiles, 1)
  scaled = reshaped * scales_broadcast

- if _FP8_IS_INT:
- quantized = torch.clamp(scaled.round(), min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
- else:
- quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
+ quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)

  quantized = quantized.reshape(original_shape)
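The simplified `Fp8Quantize` path always clamps to `[_FP8_MIN, _FP8_MAX]` and casts; the hard-coded fallback bounds of -448/448 match the finite range of `float8_e4m3fn`. A reduced, per-tensor sketch of that quantize/clamp/cast sequence, assuming a torch build that ships the FP8 dtype (the real op works block-wise over tiles):

import torch

def quantize_fp8(x: torch.Tensor):
    """Per-tensor FP8 (e4m3fn) quantization: scale into the representable range,
    clamp, cast, and keep the inverse scale for dequantization."""
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max  # 448.0
    scale = fp8_max / x.abs().max().clamp(min=1e-12)
    q = torch.clamp(x * scale, min=-fp8_max, max=fp8_max).to(fp8)
    return q, scale.reciprocal()

x = torch.randn(4, 4)
q, inv_scale = quantize_fp8(x)
x_rec = q.to(torch.float32) * inv_scale
print((x - x_rec).abs().max())  # small reconstruction error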
 
@@ -20,8 +20,8 @@ def get_target_dtype(query: torch.Tensor, module: torch.nn.Module) -> torch.dtyp
  else torch.get_autocast_gpu_dtype()
  )
  # Handle the case where the model is quantized
- elif hasattr(module.config, "_pre_quantization_dtype"):
- return module.config._pre_quantization_dtype
+ elif hasattr(module.config, "quantization_config"):
+ return module.config.dtype
  else:
  return next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
  return None
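After this change, a quantized model's target dtype for flash attention is read from `config.dtype` whenever a `quantization_config` is present, rather than from the removed `_pre_quantization_dtype`. A minimal sketch of that fallback order (the config object below is a stand-in, not a transformers config):

import torch
import torch.nn as nn
from types import SimpleNamespace

def resolve_target_dtype(module: nn.Module, config) -> torch.dtype:
    """Prefer the config-declared dtype for quantized models, otherwise infer it
    from the first nn.Linear weight found in the module."""
    if getattr(config, "quantization_config", None) is not None:
        return config.dtype
    return next(m for m in module.modules() if isinstance(m, nn.Linear)).weight.dtype

model = nn.Sequential(nn.Linear(8, 8).half())
cfg = SimpleNamespace(quantization_config=None, dtype=torch.bfloat16)
print(resolve_target_dtype(model, cfg))  # torch.float16, inferred from the weight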
@@ -16,12 +16,9 @@
  from math import sqrt

  from ..quantizers.quantizers_utils import should_convert_module
- from ..utils import is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available, logging
+ from ..utils import is_flute_available, is_hadamard_available, is_torch_available, logging


- if is_accelerate_available():
- from accelerate import init_empty_weights
-
  if is_torch_available():
  import torch
  import torch.nn as nn
@@ -569,7 +566,7 @@ def replace_with_higgs_linear(model, modules_to_not_convert: list[str] | None =
  for module_name, module in model.named_modules():
  if not should_convert_module(module_name, modules_to_not_convert):
  continue
- with init_empty_weights():
+ with torch.device("meta"):
  if isinstance(module, nn.Linear):
  new_module = HiggsLinear(
  module.in_features,
@@ -11,11 +11,14 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ import importlib.metadata
  import os
  import re
  from collections.abc import Callable
  from types import ModuleType

+ from packaging import version as pkg_version
+
  from ..utils import ENV_VARS_TRUE_VALUES, logging
  from ..utils.import_utils import is_kernels_available
  from .flash_attention import flash_attention_forward
@@ -28,10 +31,12 @@ try:
  Device,
  LayerRepository,
  Mode,
- get_kernel,
  register_kernel_mapping,
  replace_kernel_forward_from_hub,
  )
+ from kernels import (
+ get_kernel as get_kernel_hub,
+ )
  from kernels import (
  use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub,
  )
@@ -340,8 +345,6 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
  mapping[kernel_name] = None
  return None
  if _kernels_available:
- from kernels import get_kernel
-
  try:
  repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"]
  revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None)
@@ -370,7 +373,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
  if callable(is_kernel_available) and is_kernel_available():
  # Try to import the module "{kernel_name}" from parent package level
  try:
- module = importlib.import_module(f"{kernel_name}")
+ module = importlib.import_module(f"{new_kernel_name}")
  mapping[kernel_name] = module
  return module
  except Exception:
@@ -381,6 +384,20 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
  return mapping[kernel_name]


+ def get_kernel(kernel_name: str, revision: str | None = None, version: str | None = None) -> ModuleType:
+ from .. import __version__
+
+ user_agent = {"framework": "transformers", "version": __version__, "repo_id": kernel_name}
+ if _kernels_available:
+ kernels_version = importlib.metadata.version("kernels")
+ if pkg_version.parse(kernels_version) >= pkg_version.parse("0.10.4"):
+ return get_kernel_hub(kernel_name, revision=revision, version=version, user_agent=user_agent)
+ else:
+ return get_kernel_hub(kernel_name, revision=revision)
+ else:
+ raise ImportError("kernels is not installed, please install it with `pip install kernels`")
+
+
  def use_kernelized_func(module_names: list[Callable] | Callable):
  """
  This decorator attaches the target function as an attribute of the module.
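The new `get_kernel` wrapper attaches a transformers user-agent and forwards the optional `version` pin only when the installed `kernels` package is new enough (0.10.4 or later in the hunk). A hedged sketch of the same capability-gating check, using only the `importlib.metadata` and `packaging` APIs the hunk already imports (the helper name is hypothetical):

import importlib.metadata
from packaging import version as pkg_version

def supports_newer_kwargs(dist: str, minimum: str) -> bool:
    """Return True when the installed version of `dist` is at least `minimum`,
    which is how the wrapper decides whether to pass the newer kwargs through."""
    try:
        installed = importlib.metadata.version(dist)
    except importlib.metadata.PackageNotFoundError:
        return False
    return pkg_version.parse(installed) >= pkg_version.parse(minimum)

# e.g. gate the `version=` / `user_agent=` kwargs on kernels >= 0.10.4
print(supports_newer_kwargs("packaging", "20.0"))  # True on any recent environment
print(supports_newer_kwargs("kernels", "0.10.4"))  # False if kernels is not installed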
@@ -415,5 +432,6 @@ __all__ = [
  "register_kernel_mapping_transformers",
  "replace_kernel_forward_from_hub",
  "lazy_load_kernel",
+ "get_kernel",
  "use_kernelized_func",
- ]
+ ] # type: ignore