transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/quantizers/quantizer_bitnet.py

@@ -68,13 +68,12 @@ class BitNetHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        keep_in_fp32_modules: list[str] | None = None,
         **kwargs,
     ):
         from ..integrations import replace_with_bitnet_linear

         self.modules_to_not_convert = self.get_modules_to_not_convert(
-            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+            model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
         )

         model = replace_with_bitnet_linear(
@@ -87,10 +86,6 @@ class BitNetHfQuantizer(HfQuantizer):
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
         return max_memory

-    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
-        target_dtype = torch.int8
-        return target_dtype
-
     def is_serializable(self):
         return True

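Note: across the quantizers touched in this release, `_process_model_before_weight_loading` no longer receives `keep_in_fp32_modules` as an argument and reads `model._keep_in_fp32_modules` instead. A minimal sketch of how a third-party quantizer might adapt, assuming a hypothetical `MyQuantizer` with its own `quantization_config` (only the changed method is shown):

    # Illustrative sketch only; `MyQuantizer` is not part of this diff.
    from transformers.quantizers.base import HfQuantizer


    class MyQuantizer(HfQuantizer):
        def _process_model_before_weight_loading(self, model, **kwargs):
            # rc1 passed `keep_in_fp32_modules` explicitly; rc2 reads it from the model.
            self.modules_to_not_convert = self.get_modules_to_not_convert(
                model,
                self.quantization_config.modules_to_not_convert,
                model._keep_in_fp32_modules,
            )
            # ... module replacement proceeds as before ...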
transformers/quantizers/quantizer_bnb_4bit.py

@@ -51,15 +51,6 @@ class Bnb4BitHfQuantizer(HfQuantizer):
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)

-        # This describes the additional items that are saved on the state dict (on the params themselves)
-        self.bnb_keys = [
-            f"quant_state.bitsandbytes__{self.quantization_config.bnb_4bit_quant_type}",
-            "absmax",
-            "quant_map",
-        ]
-        if self.quantization_config.bnb_4bit_use_double_quant:
-            self.bnb_keys.extend(["nested_absmax", "nested_quant_map"])
-
     def validate_environment(self, *args, **kwargs):
         if not is_accelerate_available():
             raise ImportError(
@@ -87,55 +78,25 @@ class Bnb4BitHfQuantizer(HfQuantizer):
                 "for more details. "
             )

-    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
-        from accelerate.utils import CustomDtype
+    def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
+        "Return the element size (in bytes) for `param_name`."
+        if self.param_needs_quantization(model, param_name):
+            # 4 bit
+            return 0.5

-        if target_dtype != torch.int8:
-            logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization")
-            return CustomDtype.INT4
+        return super().param_element_size(model, param_name, param)

     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         import bitsandbytes as bnb

-        # TODO: maybe remove
-        # # They are on the params themselves, so we cannot easily extract the module from the name
-        if any(param_name.endswith(x) for x in self.bnb_keys):
-            return True
         module, name = get_module_from_name(model, param_name)
         return isinstance(module, bnb.nn.Linear4bit) and name != "bias"

-    def get_param_name(self, param_name: str) -> str:
-        """
-        Get the right param_name in order to get the module associated with the param.
-        This is useful for quantized stats lile absmax or quant_map as we need to update the param_name to get the module as they are stored in ...weight.absmax.
-        """
-        if self.pre_quantized:
-            # We need to get the param name of quantized weights and not its components. Otherwise, we won't be able to get the nn.Module associated.
-            if any(param_name.endswith(x) for x in self.bnb_keys):
-                param_name = (
-                    param_name.rsplit(".", 1)[0] if "quant_state." not in param_name else param_name.rsplit(".", 2)[0]
-                )
-        return param_name
-
     def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
         # need more space for buffers that are created during quantization
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
         return max_memory

-    def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
-        # TODO: remove ? is it still true ? we will move to dtype = "auto" so it will likely be either fp16 or bf16
-        if dtype is None:
-            # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
-            logger.info(
-                "Overriding dtype=%s with `dtype=torch.float16` due to "
-                "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
-                "Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
-                " dtype=torch.float16 to remove this warning.",
-                dtype,
-            )
-            dtype = torch.float16
-        return dtype
-
     def update_device_map(self, device_map):
         if device_map is None:
             if torch.cuda.is_available():
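Note: the old `adjust_target_dtype` hook (which returned a stand-in dtype such as `CustomDtype.INT4` for memory planning) is replaced by `param_element_size`, which reports a per-parameter byte size directly: 0.5 bytes per element for 4-bit bitsandbytes weights, 1 byte for 8-bit and FP8 weights, and the tensor's own element size otherwise. A rough sketch of the kind of estimate this feeds, assuming a loaded model and its attached quantizer; this is not the actual accounting code in modeling_utils.py:

    import torch

    def estimated_weight_bytes(model, quantizer) -> float:
        # Sum per-parameter sizes using the new hook: quantized weights report
        # their packed size, everything else falls back to the base-class value.
        total = 0.0
        for name, param in model.named_parameters():
            total += param.numel() * quantizer.param_element_size(model, name, param)
        return total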
@@ -159,13 +120,12 @@ class Bnb4BitHfQuantizer(HfQuantizer):
         self,
         model: "PreTrainedModel",
         device_map,
-        keep_in_fp32_modules: list[str] | None = None,
         **kwargs,
     ):
         from ..integrations import replace_with_bnb_linear

         self.modules_to_not_convert = self.get_modules_to_not_convert(
-            model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
+            model, self.quantization_config.llm_int8_skip_modules, model._keep_in_fp32_modules
         )

         if self.quantization_config.llm_int8_enable_fp32_cpu_offload:
@@ -192,10 +152,10 @@ class Bnb4BitHfQuantizer(HfQuantizer):
     def is_trainable(self) -> bool:
         return True

-    def _dequantize(self, model):
+    def _dequantize(self, model, dtype=None):
         from ..integrations import dequantize_and_replace

-        model = dequantize_and_replace(model, quantization_config=self.quantization_config)
+        model = dequantize_and_replace(model, quantization_config=self.quantization_config, dtype=dtype)
         return model

     def get_quantize_ops(self):
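Note: `_dequantize` now accepts an optional `dtype` that is forwarded to `dequantize_and_replace`, so the dtype of the restored weights can be chosen explicitly. A usage sketch, assuming `model` was loaded in 4-bit and carries its quantizer on `model.hf_quantizer`:

    import torch

    # Dequantize back to bfloat16 rather than the integration's default dtype.
    model = model.hf_quantizer._dequantize(model, dtype=torch.bfloat16)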
transformers/quantizers/quantizer_bnb_8bit.py

@@ -83,19 +83,6 @@ class Bnb8BitHfQuantizer(HfQuantizer):
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
         return max_memory

-    def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
-        if dtype is None:
-            # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
-            logger.info(
-                "Overriding dtype=%s with `dtype=torch.float16` due to "
-                "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
-                "Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
-                " dtype=torch.float16 to remove this warning.",
-                dtype,
-            )
-            dtype = torch.float16
-        return dtype
-
     def update_device_map(self, device_map):
         if device_map is None:
             if torch.cuda.is_available():
@@ -115,8 +102,12 @@ class Bnb8BitHfQuantizer(HfQuantizer):
             )
         return device_map

-    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
-        return torch.int8
+    def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
+        "Return the element size (in bytes) for `param_name`."
+        if self.param_needs_quantization(model, param_name):
+            # 8-bit
+            return 1
+        return super().param_element_size(model, param_name, param)

     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         import bitsandbytes as bnb
@@ -133,13 +124,12 @@ class Bnb8BitHfQuantizer(HfQuantizer):
         self,
         model: "PreTrainedModel",
         device_map,
-        keep_in_fp32_modules: list[str] | None = None,
         **kwargs,
     ):
         from ..integrations import replace_with_bnb_linear

         self.modules_to_not_convert = self.get_modules_to_not_convert(
-            model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
+            model, self.quantization_config.llm_int8_skip_modules, model._keep_in_fp32_modules
         )

         if self.quantization_config.llm_int8_enable_fp32_cpu_offload:
@@ -161,10 +151,10 @@ class Bnb8BitHfQuantizer(HfQuantizer):
     def is_trainable(self) -> bool:
         return True

-    def _dequantize(self, model):
+    def _dequantize(self, model, dtype=None):
         from ..integrations import dequantize_and_replace

-        model = dequantize_and_replace(model, quantization_config=self.quantization_config)
+        model = dequantize_and_replace(model, quantization_config=self.quantization_config, dtype=dtype)
         return model

     def get_quantize_ops(self):

transformers/quantizers/quantizer_compressed_tensors.py

@@ -59,10 +59,7 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
         )

     def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
-        if dtype is None:
-            logger.info("Loading model using torch.float16 for compressed-tensors quantization")
-            dtype = torch.float16
-        elif dtype != torch.float16:
+        if dtype != torch.float16:
             logger.info("We suggest you to set `dtype=torch.float16` for better efficiency with compressed_tensors.")
         return dtype

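Note: the same simplification of `update_dtype` shown above for compressed-tensors repeats in the quantizers below (EETQ, FBGEMM FP8, FP-Quant, GPTQ). The `dtype is None` branch is gone, so the quantizer no longer silently substitutes a default dtype and only warns or coerces when the requested dtype is unsuitable; the removed bitsandbytes TODO above suggests the default is now resolved (e.g. to "auto") before the quantizer hook runs. Callers who relied on the old implicit fp16/bf16 default should pass the dtype explicitly. A hedged sketch (the checkpoint name is a placeholder):

    import torch
    from transformers import AutoModelForCausalLM

    # Pass the dtype for the non-quantized layers explicitly instead of relying
    # on the quantizer to rewrite a missing dtype to float16/bfloat16.
    model = AutoModelForCausalLM.from_pretrained(
        "org/quantized-checkpoint",  # placeholder model id
        dtype=torch.float16,
    )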
transformers/quantizers/quantizer_eetq.py

@@ -64,16 +64,7 @@ class EetqHfQuantizer(HfQuantizer):
         )

     def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
-        if dtype is None:
-            dtype = torch.float16
-            logger.info(
-                "Overriding dtype=%s with `dtype=torch.float16` due to "
-                "requirements of `eetq` to enable model loading in 8-bit. "
-                "Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
-                " dtype=torch.float16 to remove this warning.",
-                dtype,
-            )
-        elif dtype != torch.float16:
+        if dtype != torch.float16:
             logger.info("We suggest you to set `dtype=torch.float16` for better efficiency with EETQ.")
         return dtype

@@ -92,13 +83,12 @@ class EetqHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        keep_in_fp32_modules: list[str] | None = None,
         **kwargs,
     ):
         from ..integrations import replace_with_eetq_linear

         self.modules_to_not_convert = self.get_modules_to_not_convert(
-            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+            model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
         )

         model = replace_with_eetq_linear(

transformers/quantizers/quantizer_fbgemm_fp8.py

@@ -84,19 +84,11 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
         )

     def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
-        if dtype is None:
-            dtype = torch.bfloat16
-            logger.info(
-                "Overriding dtype=%s with `dtype=torch.bloat16` due to "
-                "requirements of `fbgemm-gpu` to enable model loading in fp8. "
-                "Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
-                " dtype=torch.bfloat16 to remove this warning.",
-                dtype,
-            )
-        elif dtype == torch.float16:
-            raise ValueError(
-                "You cannot use FP8 with dtype=torch.float16. We recommend you passing dtype=torch.bfloat16"
+        if dtype != torch.bfloat16:
+            logger.warning_once(
+                f"Setting dtype to {dtype}, but only bfloat16 is supported right now. Overwriting torch_dtype to bfloat16."
             )
+            dtype = torch.bfloat16
         return dtype

     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
@@ -119,13 +111,12 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        keep_in_fp32_modules: list[str] | None = None,
         **kwargs,
     ):
         from ..integrations import replace_with_fbgemm_fp8_linear

         self.modules_to_not_convert = self.get_modules_to_not_convert(
-            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+            model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
         )

         model = replace_with_fbgemm_fp8_linear(

transformers/quantizers/quantizer_finegrained_fp8.py

@@ -33,7 +33,7 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
             return

         if not torch.cuda.is_available() and not is_torch_xpu_available():
-            if self.pre_quantized and not self.quantization_config.dequantize:
+            if self.pre_quantized:
                 logger.warning_once(
                     "Using FP8 quantized models requires a GPU or XPU, we will default to dequantizing the model to bf16 since no GPU or XPU is available"
                 )
@@ -46,10 +46,13 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
             compute_capability = torch.cuda.get_device_capability()
             major, minor = compute_capability
             if (major < 8) or (major == 8 and minor < 9):
-                raise ValueError(
+                logger.warning_once(
                     "FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)"
-                    f", actual = `{major}.{minor}`"
+                    f", actual = `{major}.{minor}`. We will default to dequantizing the model to bf16. Feel free "
+                    f"to use a different quantization method like bitsandbytes or torchao"
                 )
+                self.quantization_config.dequantize = True
+                return

         device_map = kwargs.get("device_map")
         if device_map is None:
@@ -82,16 +85,22 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
82
85
  return True
83
86
  return False
84
87
 
88
+ def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
89
+ "Return the element size (in bytes) for `param_name`."
90
+ if self.param_needs_quantization(model, param_name):
91
+ # 8 bit, this is neeed as when `pre_quantized`` is False, we don't set the dtype of the FP8Linear in order to correctly load the weights
92
+ return 1
93
+ return super().param_element_size(model, param_name, param)
94
+
85
95
  def _process_model_before_weight_loading(
86
96
  self,
87
97
  model: "PreTrainedModel",
88
- keep_in_fp32_modules: list[str] | None = None,
89
98
  **kwargs,
90
99
  ):
91
100
  from ..integrations.finegrained_fp8 import replace_with_fp8_linear
92
101
 
93
102
  self.modules_to_not_convert = self.get_modules_to_not_convert(
94
- model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
103
+ model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
95
104
  )
96
105
 
97
106
  model = replace_with_fp8_linear(
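
The `param_element_size` hook added above reports 1 byte per element for parameters that will be FP8-quantized, which appears intended for memory accounting during loading when the checkpoint dtype does not reflect the final storage size. A rough sketch of that arithmetic follows; `estimate_param_bytes` is a hypothetical helper, not part of the quantizer.

```python
import torch


def estimate_param_bytes(param: torch.Tensor, will_be_quantized: bool) -> int:
    # Hypothetical helper: FP8-quantized weights cost 1 byte per element,
    # everything else keeps its regular torch element size, matching what
    # param_element_size reports in the hunk above.
    element_size = 1 if will_be_quantized else param.element_size()
    return param.numel() * element_size


weight = torch.empty(4096, 4096, dtype=torch.bfloat16)
print(estimate_param_bytes(weight, will_be_quantized=True))   # 16777216 (1 byte/elem)
print(estimate_param_bytes(weight, will_be_quantized=False))  # 33554432 (2 bytes/elem)
```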
@@ -103,7 +112,7 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):

     # NOTE: TP is applied before quantization so this is only to add hooks.
     # Quantization is incompatible with DTensors, so we have to anyway have
-    # gathers! But it should be model independant -> figure out where to put
+    # gathers! But it should be model independent -> figure out where to put
     # the gather and that's it.
     def update_tp_plan(self, config):
         if "Qwen3" in config.__class__.__name__:
@@ -137,10 +146,6 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
     def is_trainable(self) -> bool:
         return False

-    def get_accelerator_warm_up_factor(self):
-        # Pre-processing is done cleanly, so we can allocate everything here
-        return 2
-
     def get_quantize_ops(self):
         from ..integrations.finegrained_fp8 import Fp8Quantize

@@ -78,11 +78,11 @@ class FPQuantHfQuantizer(HfQuantizer):
             )

     def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
-        if dtype is None:
-            logger.info("`dtype` is None. Setting `dtype=torch.bfloat16` for qutlass compatibility.")
+        if dtype != torch.bfloat16:
+            logger.warning_once(
+                f"Setting dtype to {dtype}, but only bfloat16 is supported right now. Overwriting torch_dtype to bfloat16."
+            )
             dtype = torch.bfloat16
-        elif dtype != torch.bfloat16:
-            raise ValueError(f"Invalid `dtype` {dtype}. fp_quant quantization only supports `dtype=torch.bfloat16`.")
         return dtype

     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
@@ -66,10 +66,7 @@ class GptqHfQuantizer(HfQuantizer):
             raise ImportError("The gptqmodel version should be >= 1.4.3, optimum version should >= 1.24.0")

     def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
-        if dtype is None:
-            dtype = torch.float16
-            logger.info("Loading the model in `torch.float16`. To overwrite it, set `dtype` manually.")
-        elif dtype != torch.float16:
+        if dtype != torch.float16:
             logger.info("We suggest you to set `dtype=torch.float16` for better efficiency with GPTQ.")
         return dtype

@@ -69,10 +69,7 @@ class HiggsHfQuantizer(HfQuantizer):
             )

     def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
-        if dtype is None:
-            logger.info("`dtype` is None. Setting `dtype=torch.float16` for FLUTE compatibility.")
-            dtype = torch.float16
-        elif dtype != torch.float16 and dtype != torch.bfloat16:
+        if dtype != torch.float16 and dtype != torch.bfloat16:
             raise ValueError(
                 f"Invalid `dtype` {dtype}. HIGGS quantization only supports `dtype=torch.float16` or `dtype=torch.bfloat16`."
             )
@@ -116,13 +113,12 @@ class HiggsHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        keep_in_fp32_modules: list[str] | None = None,
         **kwargs,
     ):
         from ..integrations import replace_with_higgs_linear

         self.modules_to_not_convert = self.get_modules_to_not_convert(
-            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+            model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
         )

         replace_with_higgs_linear(
@@ -53,7 +53,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
         """Lazy import and initialize kernels only when needed"""
         if self.triton_kernels_hub is None:
             try:
-                from kernels import get_kernel
+                from ..integrations.hub_kernels import get_kernel

                 self.triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
             except ImportError:
@@ -135,18 +135,6 @@ class Mxfp4HfQuantizer(HfQuantizer):
                 "Please use a quantized checkpoint or remove the CPU or disk device from the device_map."
             )

-    def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
-        if dtype is None:
-            dtype = torch.bfloat16
-            logger.info(
-                "Overriding dtype=%s with `dtype=torch.bfloat16` due to "
-                "requirements of `fbgemm-gpu` to enable model loading in fp4. "
-                "Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
-                " dtype=torch.bfloat16 to remove this warning.",
-                dtype,
-            )
-        return dtype
-
     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         from ..integrations import Mxfp4GptOssExperts

@@ -167,7 +155,6 @@ class Mxfp4HfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        keep_in_fp32_modules: list[str] | None = None,
         use_kernels: bool = False,
         **kwargs,
     ):
@@ -182,7 +169,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
             self.quantization_config.dequantize = True

         self.modules_to_not_convert = self.get_modules_to_not_convert(
-            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+            model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
         )

         model = replace_with_mxfp4_linear(
@@ -215,19 +202,6 @@ class Mxfp4HfQuantizer(HfQuantizer):
             )
         return config

-    def get_param_name(self, param_name: str) -> str:
-        if self.quantization_config.dequantize:
-            if "_blocks" in param_name:
-                return param_name.replace("_blocks", "")
-            elif "_scales" in param_name:
-                return param_name.replace("_scales", "")
-        elif not self.pre_quantized:
-            if param_name.endswith("gate_up_proj"):
-                return param_name.replace("gate_up_proj", "gate_up_proj_blocks")
-            if param_name.endswith("down_proj"):
-                return param_name.replace("down_proj", "down_proj_blocks")
-        return param_name
-
     def get_state_dict_and_metadata(self, model):
         from ..integrations import Mxfp4GptOssExperts

@@ -44,6 +44,13 @@ class QuantoHfQuantizer(HfQuantizer):

     def __init__(self, quantization_config: QuantoConfig, **kwargs):
         super().__init__(quantization_config, **kwargs)
+        map_to_param_size = {
+            "int8": 1,
+            "float8": 1,
+            "int4": 0.5,
+            "int2": 0.25,
+        }
+        self.quantized_param_size = map_to_param_size.get(self.quantization_config.weights, None)

     def validate_environment(self, *args, **kwargs):
         if not is_optimum_quanto_available():
@@ -83,25 +90,18 @@ class QuantoHfQuantizer(HfQuantizer):
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
         return max_memory

-    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
-        from accelerate.utils import CustomDtype
+    def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
+        "Return the element size (in bytes) for `param_name`."
+        if self.param_needs_quantization(model, param_name) and self.quantized_param_size is not None:
+            return self.quantized_param_size

-        mapping = {
-            "int8": torch.int8,
-            "float8": CustomDtype.FP8,
-            "int4": CustomDtype.INT4,
-            "int2": CustomDtype.INT2,
-        }
-        target_dtype = mapping[self.quantization_config.weights]
-        return target_dtype
+        return super().param_element_size(model, param_name, param)

-    def _process_model_before_weight_loading(
-        self, model: "PreTrainedModel", keep_in_fp32_modules: list[str] | None = None, **kwargs
-    ):
+    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
         from ..integrations import replace_with_quanto_layers

         self.modules_to_not_convert = self.get_modules_to_not_convert(
-            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+            model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
         )

         model = replace_with_quanto_layers(
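
For QuantoHfQuantizer, the `accelerate.utils.CustomDtype` mapping is replaced by a plain bytes-per-element table (`map_to_param_size` in `__init__`) that the new `param_element_size` consults. A small illustrative sketch of the same arithmetic; the helper name and the raise-on-unknown-format behavior are assumptions, not the library's API (the quantizer itself falls back to the regular element size).

```python
# Bytes per element for each quanto weight format, copied from the
# map_to_param_size table added in __init__ above.
BYTES_PER_ELEMENT = {"int8": 1, "float8": 1, "int4": 0.5, "int2": 0.25}


def quantized_param_bytes(numel: int, weights: str) -> float:
    # Hypothetical helper: size of a quantized parameter in bytes.
    try:
        return numel * BYTES_PER_ELEMENT[weights]
    except KeyError as exc:
        raise ValueError(f"Unsupported quanto weight format: {weights}") from exc


print(quantized_param_bytes(1024 * 1024, "int4"))  # 524288.0 (half a byte per element)
```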
@@ -51,24 +51,19 @@ class SpQRHfQuantizer(HfQuantizer):
             raise ImportError("Using `spqr` quantization requires SpQR: `pip install spqr_quant[gpu]`")

     def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
-        if dtype is None:
-            dtype = torch.float16
-            logger.info("Assuming SpQR inference on GPU and loading the model in `torch.float16`.")
-        elif dtype != torch.float16:
+        if dtype != torch.float16:
             raise ValueError(
-                "You cannot use any type other than torch.float16 for SpQR. Please either leave it None or set it to"
-                "torch.float16 explicitly."
+                "You cannot use any type other than torch.float16 for SpQR. Please set it totorch.float16 explicitly."
             )
         return dtype

     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        keep_in_fp32_modules: list[str] | None = None,
         **kwargs,
     ):
         self.modules_to_not_convert = self.get_modules_to_not_convert(
-            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+            model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
         )
         replace_with_spqr_linear(
             model,
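
Across the quantizers above, `_process_model_before_weight_loading` no longer accepts a `keep_in_fp32_modules` argument; each call now passes `model._keep_in_fp32_modules` to `get_modules_to_not_convert`. Below is a minimal sketch of the kind of merge this implies, assuming the model exposes that attribute; `merge_skip_lists` and `DummyModel` are made up for illustration and do not reflect the real `get_modules_to_not_convert` signature.

```python
def merge_skip_lists(config_skip, keep_in_fp32):
    # Hypothetical stand-in: union of the quantization config's skip list and
    # the model's _keep_in_fp32_modules, preserving order without duplicates.
    merged = []
    for name in (config_skip or []) + (keep_in_fp32 or []):
        if name not in merged:
            merged.append(name)
    return merged


class DummyModel:
    _keep_in_fp32_modules = ["lm_head"]


model = DummyModel()
print(merge_skip_lists(["model.embed_tokens"], model._keep_in_fp32_modules))
# ['model.embed_tokens', 'lm_head']
```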