transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/integrations/awq.py
@@ -13,21 +13,14 @@
 # limitations under the License.
 "AWQ (Activation aware Weight Quantization) integration file"
 
-import importlib
+from typing import Optional, Union
 
-from packaging import version
+from ..quantizers.quantizers_utils import should_convert_module
+from ..utils import is_accelerate_available, is_torch_available, logging
 
-from ..activations import ACT2FN
-from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ..modeling_utils import PreTrainedModel
-from ..utils import is_auto_awq_available, is_ipex_available, is_torch_available, logging
-from ..utils.quantization_config import (
-    AwqBackendPackingMethod,
-    AwqConfig,
-    AWQLinearVersion,
-    ExllamaVersion,
-)
 
+if is_accelerate_available():
+    from accelerate import init_empty_weights
 
 if is_torch_available():
     import torch
@@ -35,44 +28,6 @@ if is_torch_available():
 
 logger = logging.get_logger(__name__)
 
-AWQ_FUSED_MAPPINGS = {
-    "mistral": {
-        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
-        "mlp": ["gate_proj", "up_proj", "down_proj"],
-        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
-        "use_alibi": False,
-    },
-    "mixtral": {
-        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
-        "mlp": ["w1", "w3", "w2"],
-        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
-        "use_alibi": False,
-    },
-    "llama": {
-        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
-        "mlp": ["gate_proj", "up_proj", "down_proj"],
-        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
-        "use_alibi": False,
-    },
-    "llava": {
-        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
-        "mlp": ["gate_proj", "up_proj", "down_proj"],
-        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
-        "use_alibi": False,
-    },
-    "qwen2": {
-        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
-        "mlp": ["gate_proj", "up_proj", "down_proj"],
-        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
-        "use_alibi": False,
-    },
-    "qwen3": {
-        "attention": ["q_proj", "k_proj", "v_proj", "o_proj", "q_norm", "k_norm"],
-        "mlp": ["gate_proj", "up_proj", "down_proj"],
-        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
-        "use_alibi": False,
-    },
-}
 
 
 AWQ_SCALES_MAPPINGS = {
@@ -86,55 +41,8 @@ AWQ_SCALES_MAPPINGS = {
 }
 
 
-if is_auto_awq_available():
-    from awq.modules.fused.attn import RoPE
-
-    class AWQRoPE(RoPE):
-        """
-        AWQRoPE module for hacking rope implementation in AWQ fused attention modules to support more models.
-
-        Args:
-            rope_type (`str`):
-                The rope type to use.
-            head_dim (`int`):
-                The head dimension.
-            max_seq_len (`int`):
-                The maximum sequence length.
-            config (`PreTrainedConfig`):
-                The model config object.
-            device (`torch.device`):
-                The device to put the module on.
-        """
-
-        def __init__(self, rope_type, head_dim, max_seq_len, config, device):
-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
-            self.inv_freq, self.attention_scaling = rope_init_fn(config, device)
-            # Use fake rope_theta to initialize the parent class
-            super().__init__(head_dim=head_dim, max_seq_len=max_seq_len, device=device, rope_theta=-1)
-
-        def precompute_freqs_cis(self, dim: int, end: int, theta=-1):
-            t = torch.arange(end, device=self.inv_freq.device)
-            freqs = torch.outer(t, self.inv_freq).float()
-            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-            del self.inv_freq  # free the memory
-            return freqs_cis
-
-        def forward(
-            self,
-            xq: torch.Tensor,
-            xk: torch.Tensor,
-            start_pos: int,
-            seqlen: int,
-            partial: bool = False,
-        ):
-            xq_out, xk_out = super().forward(xq, xk, start_pos, seqlen, partial)
-            xq_out = (xq_out * self.attention_scaling).type_as(xq)
-            xk_out = (xk_out * self.attention_scaling).type_as(xk)
-            return xq_out, xk_out
-
-
 def replace_quantization_scales(model, model_type):
-    from awq.modules.act import ScaledActivation
+    from gptqmodel.quantization.awq.modules.act import ScaledActivation
 
     if model_type not in AWQ_SCALES_MAPPINGS:
         return model
@@ -154,437 +62,63 @@ def replace_with_awq_linear(
     model,
     modules_to_not_convert=None,
     quantization_config=None,
-    current_key_name=None,
-    has_been_replaced=False,
+    device_map: Optional[Union[str, dict]] = None,
 ) -> bool:
     """
-    Public method that recursively replaces the Linear layers of the given model with AWQ quantized layers.
-    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
-    conversion has been successful or not.
-
-    During the module replacement, we also infer the backend to use through the `quantization_config` object.
+    Public method that replaces the linear layers of the given model with awq quantized layers.
 
     Args:
         model (`torch.nn.Module`):
             The model to convert, can be any `torch.nn.Module` instance.
         quantization_config (`AwqConfig`):
             The quantization config object that contains the quantization parameters.
-        modules_to_not_convert (`list`, *optional*):
-            A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
+        modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
+            A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
             converted.
-        current_key_name (`list`, *optional*):
-            A list that contains the current key name. This is used for recursion and should not be passed by the user.
-        has_been_replaced (`bool`, *optional*):
-            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
-            should not be passed by the user.
-    """
-    if modules_to_not_convert is None:
-        modules_to_not_convert = []
-
-    backend = quantization_config.backend
-
-    if not is_auto_awq_available():
-        raise ValueError(
-            "AWQ (either `autoawq` or `llmawq`) is not available. Please install it with `pip install autoawq` or check out the installation guide in https://github.com/mit-han-lab/llm-awq"
-        )
-
-    if backend == AwqBackendPackingMethod.AUTOAWQ:
-        if quantization_config.version == AWQLinearVersion.GEMM:
-            from awq.modules.linear.gemm import WQLinear_GEMM
-
-            target_cls = WQLinear_GEMM
-        elif quantization_config.version == AWQLinearVersion.GEMV:
-            from awq.modules.linear.gemv import WQLinear_GEMV
-
-            target_cls = WQLinear_GEMV
-        elif quantization_config.version == AWQLinearVersion.EXLLAMA:
-            if quantization_config.exllama_config["version"] == ExllamaVersion.ONE:
-                from awq.modules.linear.exllama import WQLinear_Exllama
-
-                target_cls = WQLinear_Exllama
-            elif quantization_config.exllama_config["version"] == ExllamaVersion.TWO:
-                from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2
-
-                target_cls = WQLinear_ExllamaV2
-            else:
-                raise ValueError(f"Unrecognized Exllama version: {quantization_config.exllama_config['version']}")
-        elif quantization_config.version == AWQLinearVersion.IPEX:
-            from awq.modules.linear.gemm_ipex import WQLinear_IPEX
-
-            target_cls = WQLinear_IPEX
-        else:
-            raise ValueError(f"Unrecognized AWQ version: {quantization_config.version}")
-    else:
-        from awq.quantize.qmodule import WQLinear
-
-        target_cls = WQLinear
-
-    for name, module in model.named_children():
-        if current_key_name is None:
-            current_key_name = []
-        current_key_name.append(name)
-
-        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
-                in_features = module.in_features
-                out_features = module.out_features
-
-                model._modules[name] = target_cls(
-                    w_bit=quantization_config.bits,
+        device_map (`Union[str, dict]`, *optional*, defaults to `None`):
+            The device map that maps the parameters to the device
+    """
+    from gptqmodel.quantization import METHOD
+    from gptqmodel.utils.importer import hf_select_quant_linear_v2
+
+    target_cls = hf_select_quant_linear_v2(
+        bits=quantization_config.bits,
+        group_size=quantization_config.group_size,
+        desc_act=False,
+        sym=False,
+        format=quantization_config.format,
+        backend=quantization_config.backend,
+        device_map=device_map,
+        quant_method=METHOD.AWQ,
+        zero_point=quantization_config.zero_point,
+        pack=False,
+    )
+
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+        with init_empty_weights():
+            if isinstance(module, nn.Linear):
+                new_module = target_cls(
+                    bits=quantization_config.bits,
+                    sym=quantization_config.sym,
+                    desc_act=quantization_config.desc_act,
                     group_size=quantization_config.group_size,
-                    in_features=in_features,
-                    out_features=out_features,
+                    in_features=module.in_features,
+                    out_features=module.out_features,
                     bias=module.bias is not None,
                     dev=module.weight.device,
+                    register_buffers=True,
                 )
+                new_module.requires_grad_(False)
+                model.set_submodule(module_name, new_module)
                 has_been_replaced = True
 
-                # Force requires grad to False to avoid unexpected errors
-                model._modules[name].requires_grad_(False)
-        if len(list(module.children())) > 0:
-            _, has_been_replaced = replace_with_awq_linear(
-                module,
-                modules_to_not_convert=modules_to_not_convert,
-                current_key_name=current_key_name,
-                quantization_config=quantization_config,
-                has_been_replaced=has_been_replaced,
-            )
-        # Remove the last key for recursion
-        current_key_name.pop(-1)
-    return model, has_been_replaced
-
-
-def get_modules_to_fuse(model, quantization_config):
-    """
-    Returns the fusing mapping given the quantization config and the model
-
-    Args:
-        model (`~PreTrainedModel`):
-            The model to fuse - note this model should have been converted into AWQ format beforehand.
-        quantization_config (`~transformers.quantization_config.AWQConfig`):
-            The quantization configuration to use.
-    """
-    if not isinstance(model, PreTrainedModel):
-        raise TypeError(f"The model should be an instance of `PreTrainedModel`, got {model.__class__.__name__}")
-
-    # Always default to `quantization_config.modules_to_fuse`
-    if quantization_config.modules_to_fuse is not None:
-        current_fused_mapping = quantization_config.modules_to_fuse
-        current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len
-    elif model.config.model_type in AWQ_FUSED_MAPPINGS:
-        current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type]
-
-        # Properly deal with the case where we have a multi-modal model as well (e.g. Llava)
-        config = model.config.get_text_config(decoder=True)
-
-        # Handle hidden_size, num_attention_heads, num_key_value_heads, rope_parameters on our own.
-        hidden_size = config.hidden_size
-        num_attention_heads = config.num_attention_heads
-        num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads)
-        rope_parameters = config.rope_parameters
-
-        # Fill `current_fused_mapping` with the expected values
-        current_fused_mapping["hidden_size"] = hidden_size
-        current_fused_mapping["num_attention_heads"] = num_attention_heads
-        current_fused_mapping["num_key_value_heads"] = num_key_value_heads
-        current_fused_mapping["rope_parameters"] = rope_parameters
-        current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len
-    else:
-        raise ValueError(
-            "Fusing mapping not found either on the quantization config or the supported `AWQ_FUSED_MAPPINGS`. Please pass a `fused_mapping` argument"
-            " in the `quantization_config` or raise an issue on transformers https://github.com/huggingface/transformers to add its support."
-        )
-    return current_fused_mapping
-
-
-def fuse_awq_modules(model, quantization_config):
-    """
-    Optionally fuse some modules in the model to speedup inference.
-
-    Args:
-        model (`~PreTrainedModel`):
-            The model to fuse - note this model should have been converted into AWQ format beforehand.
-        quantization_config (`Union[AwqConfig, dict]`):
-            The quantization configuration to use.
-    """
-    # We need to convert it from dict in order to get an AwqConfig object
-    # otherwise the fields `backend` etc. will not be available
-    # https://github.com/huggingface/transformers/pull/27411#discussion_r1414044495
-    if isinstance(quantization_config, dict):
-        quantization_config = AwqConfig.from_dict(quantization_config)
-    backend = quantization_config.backend
-
-    modules_to_fuse = get_modules_to_fuse(model, quantization_config)
-    modules_to_not_convert = getattr(quantization_config, "modules_to_not_convert", None)
-
-    if backend == AwqBackendPackingMethod.AUTOAWQ:
-        from awq.modules.fused.attn import QuantAttentionFused
-        from awq.modules.fused.mlp import QuantFusedMLP
-        from awq.modules.fused.norm import FasterTransformerRMSNorm
-
-        # Hack QuantAttentionFused to modify the return value of forward function to avoid returning past_key_value
-        old_quant_attention_fused_forward = QuantAttentionFused.forward
-
-        def new_quant_attention_fused_forward(self, *args, **kwargs):
-            attn_output, attention_weight, _ = old_quant_attention_fused_forward(self, *args, **kwargs)
-            return attn_output, attention_weight
-
-        QuantAttentionFused.forward = new_quant_attention_fused_forward
-    else:
-        raise ValueError("Fusing is only supported for the AutoAWQ backend")
-
-    fused_attention_modules = []
-
-    for name, module in model.named_modules():
-        if modules_to_not_convert is not None:
-            if any(module_name_to_not_convert in name for module_name_to_not_convert in modules_to_not_convert):
-                continue
-
-        # Replace layer norms
-        _fuse_awq_layernorm(modules_to_fuse["layernorm"], module, FasterTransformerRMSNorm)
-
-        # Replace MLP layers if awq version is not ipex.
-        if quantization_config.version != "ipex":
-            _fuse_awq_mlp(model, name, modules_to_fuse["mlp"], module, QuantFusedMLP)
-        else:
-            logger.info("The IPEX version AWQ does not support fuse mlp for now.")
-
-        # Replace attention layers
-        attention_has_been_fused = _fuse_awq_attention_layers(
-            model, module, modules_to_fuse, name, QuantAttentionFused
-        )
-
-        if attention_has_been_fused:
-            fused_attention_modules.append(name.split(".")[0])
-
-    # For AWQ fused + Llama we need to set `config._attn_implementation` = "custom" to avoid unexpected behavior and pass
-    # `None` attention mask to the fused attention modules as now the attention mask is dropped by our models and dealt
-    # by the `AttentionMaskConverter` module.
-    if len(fused_attention_modules) > 0:
-        for module_name, module in model.named_modules():
-            if any(
-                module_name in fused_attention_modules for fused_attention_parent_module in fused_attention_modules
-            ):
-                if hasattr(module, "config") and hasattr(module.config, "_attn_implementation"):
-                    module.config._attn_implementation = "custom"
-    return model
-
-
-def _fuse_awq_layernorm(fuse_module_names, module, target_cls):
-    """
-    Fuse the LayerNorm layers into a target class using autoawq
-
-    Args:
-        fuse_module_names (`list[str]`):
-            The list of module names to fuse
-        module (`nn.Module`):
-            The pytorch parent module that has layernorm modules to fuse
-        target_cls (`~autoawq.FasterTransformerRMSNorm`):
-            The `FasterTransformerRMSNorm` class as it only supports that class
-            for now.
-    """
-    for module_name in fuse_module_names:
-        if hasattr(module, module_name):
-            old_module = getattr(module, module_name)
-            module._modules[module_name] = target_cls(
-                old_module.weight,
-                old_module.variance_epsilon,
-            ).to(old_module.weight.device)
-            del old_module
-
-
-def _fuse_awq_mlp(model, current_module_name, fuse_module_names, module, target_cls):
-    """
-    Fuse the MLP layers into a target class using autoawq
-
-    Args:
-        model (`~PreTrainedModel`):
-            The input pretrained model
-        current_module_name (`str`):
-            The current submodule name
-        fuse_module_names (`list[str]`):
-            The list of module names to fuse. For the MLP layers it has to be an array
-            of length 3 that consists of the 3 MLP layers in the order (gate (dense layer post-attention) / up / down layers)
-        module (`nn.Module`):
-            The pytorch parent module that has layernorm modules to fuse
-        target_cls (`~autoawq.QuantFusedMLP`):
-            The `QuantFusedMLP` class as it only supports that class
-            for now.
-    """
-    if len(fuse_module_names) == 0:
-        return
-
-    if hasattr(module, fuse_module_names[0]):
-        gate_proj = getattr(module, fuse_module_names[0])
-        up_proj = getattr(module, fuse_module_names[1])
-        down_proj = getattr(module, fuse_module_names[2])
-
-        previous_device = gate_proj.qweight.device
-
-        # Deal also with the case model has `text_config` attribute
-        config = model.config.get_text_config(decoder=True)
-        hidden_act = config.hidden_act
-        activation_fn = ACT2FN[hidden_act]
-        new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn)
-
-        parent_name, child_name = current_module_name.rsplit(".", 1)
-        parent = model.get_submodule(parent_name)
-        setattr(parent, child_name, new_module.to(previous_device))
-
-        del gate_proj, up_proj, down_proj
-
-
-def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_name, target_cls):
-    """
-    Fuse the Attention layers into a target class using autoawq
-
-    Args:
-        model (`~PreTrainedModel`):
-            The input pretrained model
-        module (`nn.Module`):
-            The pytorch parent module that has layernorm modules to fuse
-        modules_to_fuse (`list[str]`):
-            The module fusing mapping. The dictionary has to contain a field `attention` with attention module names
-            in the correct order: q, k, v, o layer, (q_norm, k_norm) optional
-        current_module_name (`str`):
-            The current submodule name
-        target_cls (`~autoawq.QuantAttentionFused`):
-            The `QuantAttentionFused` class as it only supports that class
-            for now.
-    """
-    from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-
-    module_has_been_fused = False
-
-    if len(modules_to_fuse["attention"]) == 0:
-        return module_has_been_fused
-
-    if hasattr(module, modules_to_fuse["attention"][0]):
-        # First, we pack the QKV layers together
-        q_proj = getattr(module, modules_to_fuse["attention"][0])
-
-        if isinstance(q_proj, WQLinear_GEMV):
-            linear_target_cls = WQLinear_GEMV
-            cat_dim = 0
-        elif isinstance(q_proj, WQLinear_GEMM):
-            linear_target_cls = WQLinear_GEMM
-            cat_dim = 1
-        elif is_ipex_available() and version.parse(importlib.metadata.version("autoawq")) > version.parse("0.2.6"):
-            from awq.modules.linear import WQLinear_IPEX
-
-            if isinstance(q_proj, WQLinear_IPEX):
-                linear_target_cls = WQLinear_IPEX
-                cat_dim = 1
-        else:
-            raise ValueError("Unsupported q_proj type: {type(q_proj)}")
-
-        previous_device = q_proj.qweight.device
-
-        k_proj = getattr(module, modules_to_fuse["attention"][1])
-        v_proj = getattr(module, modules_to_fuse["attention"][2])
-        o_proj = getattr(module, modules_to_fuse["attention"][3])
-
-        # maybe there are q_norm and k_norm layers
-        if len(modules_to_fuse["attention"]) > 4:
-            q_norm = getattr(module, modules_to_fuse["attention"][4])
-            k_norm = getattr(module, modules_to_fuse["attention"][5])
-        else:
-            q_norm = None
-            k_norm = None
-
-        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
-
-        qkv_layer = linear_target_cls(
-            q_proj.w_bit,
-            q_proj.group_size,
-            q_proj.in_features,
-            q_proj.out_features + k_proj.out_features + v_proj.out_features,
-            q_proj.bias is not None,
-            next(iter(module.state_dict().values())).device,
-        )
-
-        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=cat_dim)
-        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=cat_dim)
-        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=cat_dim)
-
-        if isinstance(qkv_layer, WQLinear_GEMV):
-            qkv_layer.split_k_iters = q_proj.split_k_iters
-
-        qkv_layer.bias = bias
-
-        fused_attention_layer = target_cls(
-            modules_to_fuse["hidden_size"],
-            modules_to_fuse["num_attention_heads"],
-            modules_to_fuse["num_key_value_heads"],
-            qkv_layer,
-            o_proj,
-            previous_device,
-            modules_to_fuse["max_seq_len"],
-            use_alibi=modules_to_fuse["use_alibi"],
-            # The default value in autoawq is set to 10000.0
-            rope_theta=modules_to_fuse["rope_parameters"].get("rope_theta", 10000.0),
-            q_norm=q_norm,
-            k_norm=k_norm,
-        )
-
-        # Hack the rope module if not using alibi and rope_type is not default
-        # As the default rope implementation in autoawq only supports the "default" rope type
-        rope_type = modules_to_fuse["rope_parameters"].get("rope_type", "default")
-        if not modules_to_fuse["use_alibi"] and rope_type != "default":
-            fused_attention_layer.rope = AWQRoPE(
-                rope_type,
-                modules_to_fuse["hidden_size"] // modules_to_fuse["num_attention_heads"],
-                modules_to_fuse["max_seq_len"],
-                model.config.get_text_config(decoder=True),
-                previous_device,
-            )
-
-        fused_attention_layer.is_hf_transformers = True
-
-        parent_name, child_name = current_module_name.rsplit(".", 1)
-        parent = model.get_submodule(parent_name)
-        setattr(parent, child_name, fused_attention_layer.to(previous_device))
-
-        del q_proj, k_proj, v_proj, o_proj, q_norm, k_norm
-        module_has_been_fused = True
-
-    return module_has_been_fused
-
-
-def post_init_awq_exllama_modules(model, exllama_config):
-    """
-    Runs post init for Exllama layers which performs:
-        - Weights unpacking, reordering and repacking
-        - Devices scratch space allocation
-    """
-
-    if exllama_config["version"] == ExllamaVersion.ONE:
-        from awq.modules.linear.exllama import exllama_post_init
-
-        model = exllama_post_init(model)
-    elif exllama_config["version"] == ExllamaVersion.TWO:
-        from awq.modules.linear.exllamav2 import exllamav2_post_init
-
-        model = exllamav2_post_init(
-            model,
-            max_input_len=exllama_config["max_input_len"],
-            max_batch_size=exllama_config["max_batch_size"],
+    if not has_been_replaced:
+        logger.warning(
+            "You are loading your model using eetq but no linear modules were found in your model."
+            " Please double check your model architecture, or submit an issue on github if you think this is"
+            " a bug."
         )
-    else:
-        raise ValueError(f"Unrecognized Exllama version: {exllama_config['version']}")
-
-    return model
-
-
-def post_init_awq_ipex_modules(model):
-    """
-    Runs post init for IPEX layers which performs:
-        - Weights packing, reordering and repacking
-    """
-
-    from awq.modules.linear.gemm_ipex import ipex_post_init
-
-    model = ipex_post_init(model)
 
     return model
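
The net effect of this file's rewrite is that the AWQ integration now sources its quantized linear layers from gptqmodel (via the `hf_select_quant_linear_v2` call above) instead of autoawq, and drops the fused-module and Exllama/IPEX post-init paths. A minimal usage sketch of how the loading path is typically exercised follows; it is not taken from the diff, the checkpoint id is a placeholder, and the exact gptqmodel/accelerate version requirements are not stated here.

# Hedged sketch: loading an AWQ checkpoint on transformers 5.0.0rc1, where the
# quantized linears are expected to come from gptqmodel rather than autoawq.
# Assumes gptqmodel and accelerate are installed; the model id is hypothetical.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "some-org/some-model-awq"  # placeholder AWQ-quantized checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

inputs = tokenizer("Hello, AWQ!", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))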