transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0

transformers/quantizers/quantizer_bnb_8bit.py
@@ -25,6 +25,8 @@ from ..utils import (
     is_accelerate_available,
     is_bitsandbytes_available,
     is_torch_available,
+    is_torch_hpu_available,
+    is_torch_npu_available,
     is_torch_xpu_available,
     logging,
 )
@@ -35,34 +37,20 @@ if is_torch_available():
     import torch
 
     from ..core_model_loading import WeightConverter
-    from ..pytorch_utils import Conv1D
 
 logger = logging.get_logger(__name__)
 
 
 class Bnb8BitHfQuantizer(HfQuantizer):
     """
-    8-bit quantization from bitsandbytes quantization method:
-    before loading: converts transformer layers into Linear8bitLt during loading: load 16bit weight and pass to the
-    layer object after: quantizes individual weights in Linear8bitLt into 8bit at fitst .cuda() call
-    saving:
-        from state dict, as usual; saves weights and 'SCB' component
-    loading:
-        need to locate SCB component and pass to the Linear8bitLt object
+    8-bit quantization from bitsandbytes quantization method
    """
 
-    use_keep_in_fp32_modules = True
-    requires_parameters_quantization = True
     requires_calibration = False
 
-    required_packages = ["bitsandbytes", "accelerate"]
-
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)
 
-        if self.quantization_config.llm_int8_skip_modules is not None:
-            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
-
     def validate_environment(self, *args, **kwargs):
         if not is_accelerate_available():
             raise ImportError(
@@ -78,17 +66,9 @@ class Bnb8BitHfQuantizer(HfQuantizer):
         validate_bnb_backend_availability(raise_exception=True)
 
         device_map = kwargs.get("device_map")
-        if (
-            device_map is not None
-            and isinstance(device_map, dict)
-            and not self.quantization_config.llm_int8_enable_fp32_cpu_offload
-        ):
-            device_map_without_lm_head = {
-                key: device_map[key] for key in device_map if key not in self.modules_to_not_convert
-            }
-            if set(device_map.values()) == {"cpu"}:
-                pass
-            elif "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
+        if not self.quantization_config.llm_int8_enable_fp32_cpu_offload and isinstance(device_map, dict):
+            values = set(device_map.values())
+            if values != {"cpu"} and ("cpu" in values or "disk" in values):
                 raise ValueError(
                     "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
                     "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
@@ -120,6 +100,10 @@ class Bnb8BitHfQuantizer(HfQuantizer):
         if device_map is None:
             if torch.cuda.is_available():
                 device_map = {"": torch.cuda.current_device()}
+            elif is_torch_npu_available():
+                device_map = {"": f"npu:{torch.npu.current_device()}"}
+            elif is_torch_hpu_available():
+                device_map = {"": f"hpu:{torch.hpu.current_device()}"}
             elif is_torch_xpu_available():
                 device_map = {"": torch.xpu.current_device()}
             else:
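
Note: with this change the default single-device placement probes accelerators in the order CUDA → NPU → HPU → XPU. A minimal standalone sketch of the same selection logic; the `pick_default_device_map` name is hypothetical and not a transformers API, while the `is_torch_*_available` helpers are the real ones imported above:

import torch

from transformers.utils import is_torch_hpu_available, is_torch_npu_available, is_torch_xpu_available


def pick_default_device_map() -> dict:
    # Mirrors the branch order in the hunk above: CUDA first, then the newly
    # supported Ascend NPU and Intel Gaudi HPU backends, then XPU.
    if torch.cuda.is_available():
        return {"": torch.cuda.current_device()}
    if is_torch_npu_available():
        return {"": f"npu:{torch.npu.current_device()}"}
    if is_torch_hpu_available():
        return {"": f"hpu:{torch.hpu.current_device()}"}
    if is_torch_xpu_available():
        return {"": torch.xpu.current_device()}
    # The fallback branch lies outside this hunk; CPU is assumed here.
    return {"": "cpu"}
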
@@ -132,61 +116,14 @@ class Bnb8BitHfQuantizer(HfQuantizer):
         return device_map
 
     def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
-        if target_dtype != torch.int8:
-            logger.info("target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization")
         return torch.int8
 
-    def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
-        bnb_keys = ["SCB", "weight_format"]
-        return [k for k in unexpected_keys if not any(k.endswith(x) for x in bnb_keys)]
-
     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         import bitsandbytes as bnb
 
         module, name = get_module_from_name(model, param_name)
         return isinstance(module, bnb.nn.Linear8bitLt) and name != "bias"
 
-    def create_quantized_param(
-        self,
-        model: "PreTrainedModel",
-        param_value: "torch.Tensor",
-        param_name: str,
-        target_device: "torch.device",
-        **kwargs,
-    ):
-        import bitsandbytes as bnb
-
-        module, tensor_name = get_module_from_name(model, param_name)
-
-        if self.pre_quantized and not self.is_serializable():
-            raise ValueError(
-                "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
-                "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
-            )
-        # Those 2 can only happen when self.pre_quantized == True
-        if tensor_name == "SCB":
-            setattr(module.weight, "SCB", param_value.to(target_device))
-            return
-        # It's not used, but it's getting serialized for BC reason...
-        elif tensor_name == "weight_format":
-            return
-
-        # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
-        # Since weights are saved in the correct "orientation", we skip transposing when loading.
-        if issubclass(module.source_cls, Conv1D) and not self.pre_quantized:
-            param_value = param_value.T
-
-        old_value = getattr(module, tensor_name)
-        kwargs = old_value.__dict__
-        kwargs.pop("_is_hf_initialized", None)
-        # Need to pop SCB and reset it because of bnb internals that modifies its value when switching devices ...
-        SCB = kwargs.pop("SCB", None)
-        new_value = bnb.nn.Int8Params(param_value.to("cpu"), requires_grad=False, **kwargs).to(target_device)
-        if SCB is not None:
-            setattr(new_value, "SCB", SCB)
-        # Set it to the module
-        module._parameters[tensor_name] = new_value
-
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
         model.is_loaded_in_8bit = True
         model.is_8bit_serializable = self.is_serializable()
@@ -201,23 +138,14 @@ class Bnb8BitHfQuantizer(HfQuantizer):
     ):
         from ..integrations import replace_with_bnb_linear
 
-        llm_int8_enable_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
-
         self.modules_to_not_convert = self.get_modules_to_not_convert(
             model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
         )
 
-        # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
-        if isinstance(device_map, dict) and len(device_map.keys()) > 1:
-            keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
-
-            if len(keys_on_cpu) > 0 and not llm_int8_enable_fp32_cpu_offload:
-                raise ValueError(
-                    "If you want to offload some keys to `cpu` or `disk`, you need to set "
-                    "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
-                    " converted to 8-bit but kept in 32-bit."
-                )
-            self.modules_to_not_convert.extend(keys_on_cpu)
+        if self.quantization_config.llm_int8_enable_fp32_cpu_offload:
+            if isinstance(device_map, dict):
+                keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+                self.modules_to_not_convert.extend(keys_on_cpu)
 
         model = replace_with_bnb_linear(
             model,
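
Note: offload handling is now opt-in only: CPU/disk entries in a `device_map` are appended to `modules_to_not_convert` when `llm_int8_enable_fp32_cpu_offload=True`, and the old hard error for unflagged offload is gone (that case is caught earlier in `validate_environment`). A hedged usage sketch; the checkpoint and module names are illustrative only:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,  # keep offloaded modules unquantized
)

# "lm_head" is dispatched to CPU, so it lands in modules_to_not_convert and stays
# in full precision while the GPU-resident modules become Linear8bitLt.
device_map = {"model": 0, "lm_head": "cpu"}  # illustrative layout

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # illustrative checkpoint
    quantization_config=quantization_config,
    device_map=device_map,
)
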
@@ -226,9 +154,7 @@ class Bnb8BitHfQuantizer(HfQuantizer):
             pre_quantized=self.pre_quantized,
         )
 
-        model.config.quantization_config = self.quantization_config
-
-    def is_serializable(self, safe_serialization=None):
+    def is_serializable(self):
         return True
 
     @property
@@ -238,9 +164,7 @@ class Bnb8BitHfQuantizer(HfQuantizer):
     def _dequantize(self, model):
         from ..integrations import dequantize_and_replace
 
-        model = dequantize_and_replace(
-            model, self.modules_to_not_convert, quantization_config=self.quantization_config
-        )
+        model = dequantize_and_replace(model, quantization_config=self.quantization_config)
         return model
 
     def get_quantize_ops(self):
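
Note: `dequantize_and_replace` no longer takes `modules_to_not_convert`; the user-facing entry point is unchanged. A sketch assuming an 8-bit bitsandbytes model, with an illustrative checkpoint:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # illustrative checkpoint
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)
# PreTrainedModel.dequantize() routes through the quantizer's _dequantize(),
# which now calls dequantize_and_replace(model, quantization_config=...).
model = model.dequantize()
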

transformers/quantizers/quantizer_compressed_tensors.py
@@ -31,7 +31,6 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
     """
 
     requires_calibration = True
-    required_packages = ["compressed_tensors"]
 
     def __init__(self, quantization_config: CompressedTensorsConfig, **kwargs):
         super().__init__(quantization_config, **kwargs)
@@ -58,9 +57,6 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
                 "Using `compressed_tensors` quantized models requires the compressed-tensors library: "
                 "`pip install compressed-tensors`"
             )
-        if not is_torch_available():
-            # torch already should be installed as part of compressed tensors
-            raise ImportError("torch is required for using compressed-tensors quantization")
 
     def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
         if dtype is None:
@@ -113,6 +109,6 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
         # models need to be decompressed carry out qat
         return not self.run_compressed or not self.quantization_config.is_quantization_compressed
 
-    def is_serializable(self, safe_serialization=None) -> bool:
+    def is_serializable(self) -> bool:
         """Models quantized using compressed tensors can be saved to disk"""
         return True

transformers/quantizers/quantizer_eetq.py
@@ -19,7 +19,7 @@ from .base import HfQuantizer
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
 
-from ..utils import is_accelerate_available, is_eetq_available, is_torch_available, logging
+from ..utils import is_accelerate_available, is_kernels_available, is_torch_available, logging
 from .quantizers_utils import get_module_from_name
 
 
@@ -32,40 +32,17 @@ logger = logging.get_logger(__name__)
 
 class EetqHfQuantizer(HfQuantizer):
     """
-    8-bit quantization from EETQ quantization method:
-    before loading: converts transformer layers into W8A16Linear during loading: load 16bit weight and pass to the
-    layer object after: quantizes individual weights in Linear8bitLt into 8bit at first .cuda() call
+    8-bit quantization from EETQ quantization method
     """
 
-    requires_parameters_quantization = True
     requires_calibration = False
 
-    required_packages = ["eetq", "accelerate"]
-
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        self.quantization_config = quantization_config
 
     def validate_environment(self, *args, **kwargs):
-        if not is_eetq_available():
-            raise ImportError(
-                "Using `eetq` 8-bit quantization requires eetq."
-                "Please install the latest version of eetq from : https://github.com/NetEase-FuXi/EETQ"
-            )
-
-        try:
-            import eetq  # noqa: F401
-        except ImportError as exc:
-            if "shard_checkpoint" in str(exc):
-                # EETQ 1.0.0 is currently broken with the latest transformers because it tries to import the removed
-                # shard_checkpoint function, see https://github.com/NetEase-FuXi/EETQ/issues/34.
-                # TODO: Update message once eetq releases a fix
-                raise ImportError(
-                    "You are using a version of EETQ that is incompatible with the current transformers version. "
-                    "Either downgrade transformers to <= v4.46.3 or, if available, upgrade EETQ to > v1.0.0."
-                ) from exc
-            else:
-                raise
+        if not is_kernels_available():
+            raise ImportError("Loading an EETQ quantized model requires kernels (`pip install kernels`)")
 
         if not is_accelerate_available():
             raise ImportError("Loading an EETQ quantized model requires accelerate (`pip install accelerate`)")
@@ -79,8 +56,8 @@ class EetqHfQuantizer(HfQuantizer):
                 "You have loaded an EETQ model on CPU and have a CUDA device available, make sure to set "
                 "your model on a GPU device in order to run your model."
             )
-        elif device_map is not None:
-            if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
+        elif isinstance(device_map, dict):
+            if len(device_map) > 1 and "cpu" in device_map.values() or "disk" in device_map.values():
                 raise ValueError(
                     "You are attempting to load an EETQ model with a device_map that contains a CPU or disk device."
                     " This is not supported. Please remove the CPU or disk device from the device_map."
@@ -101,7 +78,7 @@ class EetqHfQuantizer(HfQuantizer):
         return dtype
 
     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
-        from eetq import EetqLinear
+        from ..integrations.eetq import EetqLinear
 
         module, tensor_name = get_module_from_name(model, param_name)
 
@@ -112,31 +89,6 @@ class EetqHfQuantizer(HfQuantizer):
                 return True
         return False
 
-    def create_quantized_param(
-        self,
-        model: "PreTrainedModel",
-        param_value: "torch.Tensor",
-        param_name: str,
-        target_device: "torch.device",
-        **kwargs,
-    ):
-        from eetq import EetqLinear, quantize_and_preprocess_weights
-
-        module, tensor_name = get_module_from_name(model, param_name)
-        new_value, weight_scale = quantize_and_preprocess_weights(param_value)
-
-        # Samity check
-        if isinstance(module, EetqLinear):
-            if self.pre_quantized or tensor_name == "bias":
-                if tensor_name == "weight" and param_value.dtype != torch.int8:
-                    raise ValueError("Expect quantized weights but got an unquantized weight")
-            else:
-                if tensor_name == "weight_scale":
-                    raise ValueError("Expect unquantized weights but got a quantized weight_scale")
-
-        module._buffers[tensor_name] = new_value.to(target_device)
-        module.register("weight_scales", weight_scale.to(target_device))
-
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
@@ -150,17 +102,17 @@ class EetqHfQuantizer(HfQuantizer):
         )
 
         model = replace_with_eetq_linear(
-            model,
-            modules_to_not_convert=self.modules_to_not_convert,
-            quantization_config=self.quantization_config,
-            pre_quantized=self.pre_quantized,
+            model, modules_to_not_convert=self.modules_to_not_convert, pre_quantized=self.pre_quantized
         )
 
-        model.config.quantization_config = self.quantization_config
-
-    def is_serializable(self, safe_serialization=None):
+    def is_serializable(self):
         return True
 
     @property
     def is_trainable(self) -> bool:
         return True
+
+    def get_quantize_ops(self):
+        from ..integrations.eetq import EetqQuantize
+
+        return EetqQuantize(self)
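
Note: on-the-fly EETQ loading is unchanged from the user's side; only the backing implementation moved to `transformers.integrations.eetq` and the kernels hub, with the actual packing now performed by the `EetqQuantize` ops above. A hedged sketch, with an illustrative checkpoint:

from transformers import AutoModelForCausalLM, EetqConfig

quantization_config = EetqConfig("int8")  # weight-only int8 (W8A16)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # illustrative checkpoint
    device_map="auto",
    quantization_config=quantization_config,
)
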

transformers/quantizers/quantizer_fbgemm_fp8.py
@@ -19,14 +19,21 @@ from .base import HfQuantizer
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
 
-from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
+from ..utils import (
+    is_accelerate_available,
+    is_fbgemm_gpu_available,
+    is_kernels_available,
+    is_torch_available,
+    is_torch_cuda_available,
+    is_torch_xpu_available,
+    logging,
+)
 from .quantizers_utils import get_module_from_name
 
 
 if is_torch_available():
     import torch
 
-
 logger = logging.get_logger(__name__)
 
 
@@ -35,54 +42,41 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
     FP8 quantization using fbgemm kernels
     """
 
-    requires_parameters_quantization = True
     requires_calibration = False
 
-    required_packages = ["fbgemm-gpu", "accelerate"]
-
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        self.quantization_config = quantization_config
 
     def validate_environment(self, *args, **kwargs):
-        if not is_torch_available():
-            raise ImportError(
-                "Using fbgemm fp8 quantization requires torch >= 2.1.0"
-                "Please install the latest version of torch ( pip install --upgrade torch )"
-            )
-        if not is_fbgemm_gpu_available():
+        if not is_torch_cuda_available() and not is_torch_xpu_available():
+            raise ImportError("Using fbgemm fp8 quantization requires a GPU or XPU")
+        if is_torch_xpu_available() and not is_kernels_available():
+            raise ImportError("Using FP8 fbgemm on XPU requires kernels (`pip install kernels`)")
+        if is_torch_cuda_available() and not is_fbgemm_gpu_available():
             raise ImportError(
-                "Using fbgemm fp8 quantization requires fbgemm-gpu library"
+                "Loading an FP8 fbgemm quantized model on CUDA requires fbgemm-gpu library"
                 "Please install the latest version of fbgemm-gpu library by following : https://pytorch.org/FBGEMM/fbgemm_gpu-development/InstallationInstructions.html#fbgemm-gpu-install-libraries"
             )
-
         if not is_accelerate_available():
             raise ImportError(
                 "Loading an FP8 quantized model requires accelerate (`pip install --upgrade accelerate`)"
             )
-
-        if not torch.cuda.is_available():
-            raise RuntimeError("Using FP8 quantized models with fbgemm kernels requires a GPU")
-
-        compute_capability = torch.cuda.get_device_capability()
-        major, minor = compute_capability
-        if major < 9:
-            raise ValueError(
-                "FP8 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)"
-            )
+        if is_torch_cuda_available():
+            compute_capability = torch.cuda.get_device_capability()
+            major, _ = compute_capability
+            if major < 9:
+                raise ValueError(
+                    "FP8 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)"
+                )
 
         device_map = kwargs.get("device_map")
         if device_map is None:
             logger.warning_once(
-                "You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set "
-                "your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. "
+                "You have loaded an FP8 model on CPU and have a CUDA/XPU device available, make sure to set "
+                "your model on a GPU/XPU device in order to run your model. To remove this warning, pass device_map = 'cuda' or 'xpu' or 'auto'. "
             )
-        elif device_map is not None:
-            if (
-                not self.pre_quantized
-                and isinstance(device_map, dict)
-                and ("cpu" in device_map.values() or "disk" in device_map.values())
-            ):
+        elif isinstance(device_map, dict):
+            if not self.pre_quantized and ("cpu" in device_map.values() or "disk" in device_map.values()):
                 raise ValueError(
                     "You are attempting to load an FP8 model with a device_map that contains a CPU or disk device."
                     "This is not supported when the model is quantized on the fly. "
@@ -101,7 +95,7 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
             )
         elif dtype == torch.float16:
             raise ValueError(
-                "You cannot use FP8 with dtype=torch.float16.We recommend you passing dtype=torch.bfloat16"
+                "You cannot use FP8 with dtype=torch.float16. We recommend you passing dtype=torch.bfloat16"
             )
         return dtype
 
@@ -122,76 +116,6 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
                 return True
         return False
 
-    def create_quantized_param(
-        self,
-        model: "PreTrainedModel",
-        param_value: "torch.Tensor",
-        param_name: str,
-        target_device: "torch.device",
-        **kwargs,
-    ):
-        from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
-
-        module, tensor_name = get_module_from_name(model, param_name)
-
-        # Sanity checks
-        if isinstance(module, FbgemmFp8Linear):
-            if self.pre_quantized or tensor_name == "bias":
-                if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
-                    raise ValueError("Expect quantized weights but got an unquantized weight")
-            else:
-                if tensor_name == "weight_scale":
-                    raise ValueError("Expect unquantized weights but got a quantized weight_scale")
-        if isinstance(module, FbgemmFp8Llama4TextExperts):
-            if not (self.pre_quantized or tensor_name == "bias"):
-                if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
-                    raise ValueError("Expect unquantized weights but got a quantized weight_scale")
-
-        if isinstance(module, FbgemmFp8Llama4TextExperts):
-            if tensor_name == "gate_up_proj":
-                # Process each expert separately
-                # Transpose the second and third dimension
-                transposed_param = param_value.transpose(1, 2)
-
-                # Reshape to 2D for quantization
-                original_shape = transposed_param.shape
-                flattened_param = transposed_param.reshape(-1, original_shape[-1])
-
-                # Quantize using per row instead of per column
-                new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
-
-                # Reshape back to original dimensions
-                new_value = new_value_flat.reshape(original_shape)
-                new_value = new_value.transpose(1, 2)
-                weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
-            elif tensor_name == "down_proj":
-                # Process each expert separately
-                # Transpose the weights for proper quantization
-                transposed_param = param_value.transpose(1, 2)
-
-                # Reshape to 2D for quantization
-                original_shape = transposed_param.shape
-                flattened_param = transposed_param.reshape(-1, original_shape[-1])
-
-                # Quantize using per column
-                new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
-
-                # Reshape back to original dimensions
-                new_value = new_value_flat.reshape(original_shape)
-                new_value = new_value.transpose(1, 2)
-                weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
-
-            module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(weight_scale.to(target_device))
-        else:
-            new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)
-            module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(
-                weight_scale.view(weight_scale.shape[0], 1).to(target_device)
-            )
-
-        module._parameters[tensor_name] = torch.nn.Parameter(new_value.to(target_device))
-
-        del param_name
-
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
@@ -200,38 +124,18 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
     ):
         from ..integrations import replace_with_fbgemm_fp8_linear
 
-        tp_plan = model._tp_plan
         self.modules_to_not_convert = self.get_modules_to_not_convert(
             model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
         )
 
-        config = model.config
         model = replace_with_fbgemm_fp8_linear(
             model,
             modules_to_not_convert=self.modules_to_not_convert,
             quantization_config=self.quantization_config,
             pre_quantized=self.pre_quantized,
-            config=config,
-            tp_plan=tp_plan,
+            tp_plan=model._tp_plan,
         )
 
-        model.config.quantization_config = self.quantization_config
-
-    def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
-        from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
-
-        not_missing_keys = []
-        for name, module in model.named_modules():
-            if isinstance(module, (FbgemmFp8Linear, FbgemmFp8Llama4TextExperts)):
-                for missing in missing_keys:
-                    if (
-                        (name in missing or name in f"{prefix}.{missing}")
-                        and not missing.endswith(".weight")
-                        and not missing.endswith(".bias")
-                    ):
-                        not_missing_keys.append(missing)
-        return [k for k in missing_keys if k not in not_missing_keys]
-
     def update_tp_plan(self, config):
         if "Llama4" in config.__class__.__name__:
             text_plan = {
@@ -279,9 +183,14 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
 
         return config
 
-    def is_serializable(self, safe_serialization=None):
+    def is_serializable(self):
         return True
 
     @property
     def is_trainable(self) -> bool:
         return False
+
+    def get_quantize_ops(self):
+        from ..integrations.fbgemm_fp8 import FbgemmFp8Quantize
+
+        return FbgemmFp8Quantize(self)
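
Note: as with EETQ, on-the-fly FP8 quantization keeps the same user-facing API; the per-row quantization previously done in `create_quantized_param` is now handled by the `FbgemmFp8Quantize` ops returned from `get_quantize_ops()` above. A hedged sketch, with an illustrative checkpoint:

from transformers import AutoModelForCausalLM, FbgemmFp8Config

quantization_config = FbgemmFp8Config()
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # illustrative checkpoint; needs a Hopper GPU or an XPU
    device_map="auto",
    quantization_config=quantization_config,
)
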