transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff compares two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/integrations/eetq.py (+84 -79)
@@ -12,11 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ..utils import is_accelerate_available, is_eetq_available, logging
+from ..core_model_loading import ConversionOps
+from ..quantizers.quantizers_utils import should_convert_module
+from ..utils import is_accelerate_available, is_torch_available, logging
 
 
-if is_eetq_available():
-    import eetq
+if is_torch_available():
+    import torch
     import torch.nn as nn
 
 if is_accelerate_available():
@@ -25,91 +27,94 @@ if is_accelerate_available():
 logger = logging.get_logger(__name__)
 
 
-def _replace_with_eetq_linear(
-    model,
-    modules_to_not_convert=None,
-    current_key_name=None,
-    quantization_config=None,
-    has_been_replaced=False,
-    pre_quantized=False,
-):
-    """
-    Private method that wraps the recursion for module replacement.
+class EetqQuantize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
 
-    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
-    """
-    if current_key_name is None:
-        current_key_name = []
-
-    for name, module in model.named_children():
-        current_key_name.append(name)
-
-        if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            current_key_name_str = ".".join(current_key_name)
-            if not any(
-                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
-            ):
-                with init_empty_weights():
-                    in_features = module.in_features
-                    out_features = module.out_features
-                    model._modules[name] = eetq.EetqLinear(
-                        in_features, out_features, module.bias is not None, module.weight.device
-                    )
-                    if pre_quantized:
-                        model._modules[name].register_scale(module.weight.device)
-                    has_been_replaced = True
-
-                # Force requires grad to False to avoid unexpected errors
-                model._modules[name].requires_grad_(False)
-        if len(list(module.children())) > 0:
-            _, has_been_replaced = _replace_with_eetq_linear(
-                module,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
-                has_been_replaced=has_been_replaced,
-                pre_quantized=pre_quantized,
-            )
-        # Remove the last key for recursion
-        current_key_name.pop(-1)
-    return model, has_been_replaced
-
-
-def replace_with_eetq_linear(
-    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
-):
-    """
-    A helper function to replace all `torch.nn.Linear` modules by `eetq.EetqLinear` modules from the `eetq`
-    library. This will enable running your models using high performance int8 weight-only gemm kerner from
-    FasterTransformer and TensorRT-LLM. Make sure `eetq` compiled with the correct CUDA
-    version of your hardware is installed before running this function. EETQ shall be installed via the source
-    'https://github.com/NetEase-FuXi/EETQ'
+    def convert(
+        self, input_dict: dict[str, list[torch.Tensor]], full_layer_name: str | None = None, **kwargs
+    ) -> dict[str, torch.Tensor]:
+        _, value = tuple(input_dict.items())[0]
+        value = value[0]
+
+        value_device = value.device
+        int8_weight = torch.t(value).contiguous().cpu()
+        int8_weight, scales = eetq_kernels_hub.quant_weights(int8_weight, torch.int8, False)
+
+        int8_weight = int8_weight.to(value_device)
+        scales = scales.to(value_device)
+
+        return {full_layer_name: int8_weight, f"{full_layer_name}_scales": scales}
+
+
+class EetqLinearMMFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, weight, scales, bias=None):
+        # The forward pass can use ctx.
+        ctx.save_for_backward(x, weight, scales, bias)
+        output = eetq_kernels_hub.w8_a16_gemm(x, weight, scales)
+        output = output + bias if bias is not None else output
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight, scales, bias = ctx.saved_tensors
+        identity = torch.eye(weight.shape[0]).to(weight.device).to(input.dtype)
+
+        # Dequantize the weight
+        weight = eetq_kernels_hub.w8_a16_gemm(identity, weight, scales)
 
-    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
-    be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
-    CPU/GPU memory is required to run this function. Each weight will be quantized along the channel.
+        if ctx.needs_input_grad[0]:
+            # 2D matrix multiplication, unsqueeze to 3D
+            grad_input = grad_output.squeeze(0).matmul(weight.transpose(0, 1)).unsqueeze(0)
+
+        return grad_input, None, None, None
+
+
+class EetqLinear(nn.Module):
+    def __init__(self, in_features, out_features, dtype=torch.int8, bias=False):
+        super().__init__()
+        self.weight = nn.Parameter(torch.empty((in_features, out_features), dtype=dtype), requires_grad=False)
+        self.weight_scales = nn.Parameter(torch.empty((out_features), dtype=torch.float16))
+        if bias:
+            self.bias = nn.Parameter(torch.empty((out_features), dtype=torch.float16))
+        else:
+            self.bias = None
+
+    def forward(self, input):
+        output = EetqLinearMMFunction.apply(input, self.weight, self.weight_scales, self.bias)
+        return output
+
+
+def replace_with_eetq_linear(model, modules_to_not_convert: list[str] | None = None, pre_quantized=False):
+    """
+    A helper function to replace all `torch.nn.Linear` modules by `EetqLinear` modules.
 
     Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
-        modules_to_not_convert (`list[`str`]`, *optional*, defaults to `["lm_head"]`):
+        modules_to_not_convert (`list[`str`]`, *optional*, defaults to `None`):
            Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
            for numerical stability reasons.
-        current_key_name (`list[`str`]`, *optional*):
-            An array to track the current key of the recursion. This is used to check whether the current key (part of
-            it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
-            `disk`).
    """
-
-    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
-
-    if quantization_config.modules_to_not_convert is not None:
-        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
-    modules_to_not_convert = list(set(modules_to_not_convert))
-    model, has_been_replaced = _replace_with_eetq_linear(
-        model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
-    )
+    from kernels import get_kernel
+
+    global eetq_kernels_hub
+    eetq_kernels_hub = get_kernel("kernels-community/quantization-eetq")
+
+    has_been_replaced = False
+    # we need this to correctly materialize the weights during quantization
+    module_kwargs = {} if pre_quantized else {"dtype": None}
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+        with init_empty_weights():
+            if isinstance(module, nn.Linear):
+                new_module = EetqLinear(
+                    module.in_features, module.out_features, bias=module.bias is not None, **module_kwargs
+                )
+                model.set_submodule(module_name, new_module)
+                has_been_replaced = True
 
     if not has_been_replaced:
        logger.warning(
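
Note: the hunks above replace the old `named_children()` recursion with one flat pass over `named_modules()` plus `torch.nn.Module.set_submodule`, and fetch the int8 kernels from the Hub (`kernels-community/quantization-eetq`) instead of requiring a compiled `eetq` package. A minimal sketch of that flattened walk on a toy model; the skip-list check below merely stands in for `should_convert_module`, and `set_submodule` requires PyTorch 2.5+:

import torch.nn as nn

model = nn.ModuleDict(
    {"mlp": nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16)), "lm_head": nn.Linear(16, 100)}
)
skip = ["lm_head"]
for name, module in list(model.named_modules()):
    if not isinstance(module, nn.Linear):
        continue
    if any(name == key or name.endswith(f".{key}") for key in skip):
        continue  # kept in full precision, e.g. lm_head
    # a real EetqLinear would be built here from module.in_features / module.out_features
    model.set_submodule(name, nn.Identity())
print(model)  # mlp.0 and mlp.2 swapped out, lm_head untouched
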
transformers/integrations/fbgemm_fp8.py (+191 -145)
@@ -12,8 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from functools import lru_cache
+from typing import Optional
+
 from ..activations import ACT2FN
-from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
+from ..core_model_loading import ConversionOps
+from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
+from ..utils import (
+    is_accelerate_available,
+    is_fbgemm_gpu_available,
+    is_torch_available,
+    is_torch_xpu_available,
+    logging,
+)
 
 
 if is_torch_available():
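
Note: `_is_torch_xpu_available` is now computed once at import time and gates the `fbgemm_gpu` import entirely; on XPU builds the forward paths below route through `torch._scaled_mm` instead of `torch.ops.fbgemm`. The check itself is an ordinary transformers utility:

from transformers.utils import is_torch_xpu_available

print(is_torch_xpu_available())  # False on CUDA- or CPU-only installs
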
@@ -23,24 +34,83 @@ if is_torch_available():
 if is_accelerate_available():
     from accelerate import init_empty_weights
 
-if is_fbgemm_gpu_available():
+_is_torch_xpu_available = is_torch_xpu_available()
+
+if is_fbgemm_gpu_available() and not _is_torch_xpu_available:
     import fbgemm_gpu.experimental.gen_ai  # noqa: F401
 
 logger = logging.get_logger(__name__)
 
 
+class FbgemmFp8Quantize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self,
+        input_dict: dict[str, torch.Tensor | list[torch.Tensor]],
+        model: Optional[torch.nn.Module] = None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        target_key, value = tuple(input_dict.items())[0]
+        value = value[0]
+
+        from ..integrations import FbgemmFp8Llama4TextExperts
+
+        module, tensor_name = get_module_from_name(model, target_key)
+
+        if isinstance(module, FbgemmFp8Llama4TextExperts):
+            if tensor_name == "gate_up_proj":
+                # Process each expert separately
+                # Transpose the second and third dimension
+                transposed_param = value.transpose(1, 2)
+
+                # Reshape to 2D for quantization
+                original_shape = transposed_param.shape
+                flattened_param = transposed_param.reshape(-1, original_shape[-1])
+
+                # Quantize using per row instead of per column
+                new_value_flat, weight_scale_flat = quantize_fp8_per_row(flattened_param)
+
+                # Reshape back to original dimensions
+                new_value = new_value_flat.reshape(original_shape)
+                new_value = new_value.transpose(1, 2)
+                weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
+            elif tensor_name == "down_proj":
+                # Process each expert separately
+                # Transpose the weights for proper quantization
+                transposed_param = value.transpose(1, 2)
+
+                # Reshape to 2D for quantization
+                original_shape = transposed_param.shape
+                flattened_param = transposed_param.reshape(-1, original_shape[-1])
+
+                # Quantize using per column
+                new_value_flat, weight_scale_flat = quantize_fp8_per_row(flattened_param)
+
+                # Reshape back to original dimensions
+                new_value = new_value_flat.reshape(original_shape)
+                new_value = new_value.transpose(1, 2)
+                weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
+        else:
+            new_value, weight_scale = quantize_fp8_per_row(value)
+            weight_scale = torch.nn.Parameter(weight_scale.view(weight_scale.shape[0], 1))
+
+        return {target_key: torch.nn.Parameter(new_value), f"{target_key}_scale": weight_scale}
+
+
 class FbgemmFp8Linear(torch.nn.Linear):
-    def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
+    def __init__(self, in_features, out_features, bias, dtype=torch.float8_e4m3fn):
         super().__init__(in_features, out_features, bias)
         self.in_features = in_features
         self.out_features = out_features
 
-        self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
-        self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype))
+        self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=dtype))
+        self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=torch.float32))
         self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
 
         if bias:
-            self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype))
+            self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=torch.float32))
         else:
             self.bias = None
 
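
Note: the rebuilt `FbgemmFp8Linear` stores its weight as `float8_e4m3fn` with shape `[out_features, in_features]` and one float32 scale per output row (`weight_scale`, shape `[out_features, 1]`), which is what `FbgemmFp8Quantize.convert` produces via `quantize_fp8_per_row`. A plain-PyTorch sketch of that per-row scheme, not the fbgemm kernel itself (rounding and the `scale_ub` clamp are handled differently there):

import torch

w = torch.randn(4, 8)  # [out_features, in_features]
fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
scale = (w.abs().amax(dim=1, keepdim=True) / fp8_max).clamp(min=1e-12)  # [4, 1], one scale per row
w_fp8 = (w / scale).to(torch.float8_e4m3fn)  # what lands in `weight`
w_hat = w_fp8.to(torch.float32) * scale  # the dequantized view the GEMM effectively uses
print(scale.shape, (w - w_hat).abs().max())  # torch.Size([4, 1]) and a small per-row error
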
@@ -49,18 +119,26 @@ class FbgemmFp8Linear(torch.nn.Linear):
         output_shape = (*x.shape[:-1], -1)
         # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
         # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
-        x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-            x.view(-1, x.shape[-1]).contiguous(), scale_ub=self.input_scale_ub
-        )
+        x_quantized, x_scale = quantize_fp8_per_row(x.view(-1, x.shape[-1]).contiguous(), scale_ub=self.input_scale_ub)
         # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
         # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
 
         # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
         weight_scale_float32 = self.weight_scale.to(torch.float32)
-        output = torch.ops.fbgemm.f8f8bf16_rowwise(
-            x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
-        )
-        output = output + self.bias if self.bias is not None else output
+        if _is_torch_xpu_available:
+            output = torch._scaled_mm(
+                x_quantized,
+                self.weight.t(),
+                scale_a=x_scale.unsqueeze(-1),
+                scale_b=weight_scale_float32.t(),
+                out_dtype=x.dtype,
+                bias=self.bias,
+            )
+        else:
+            output = torch.ops.fbgemm.f8f8bf16_rowwise(
+                x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
+            )
+            output = output + self.bias if self.bias is not None else output
         # Hacky for now, we have the output to the device of x
         output = output.to(x.device)
         output = output.reshape(output_shape)
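
Note: both branches above compute the same rowwise-scaled product; `torch.ops.fbgemm.f8f8bf16_rowwise` and `torch._scaled_mm` fuse it on fp8-capable hardware. Emulated in fp32, for shape intuition only:

import torch

M, K, N = 4, 8, 3
x_q = torch.randn(M, K)     # stand-in for the fp8-quantized activations
w_q = torch.randn(N, K)     # stand-in for the fp8 weight [out_features, in_features]
x_scale = torch.rand(M, 1)  # per-row activation scales
w_scale = torch.rand(N, 1)  # per-row weight scales (weight_scale_float32)
out = (x_q @ w_q.t()) * x_scale * w_scale.t()  # [M, N], scales applied on the way out
print(out.shape)
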
@@ -112,168 +190,136 @@ class FbgemmFp8Llama4TextExperts(nn.Module):
             expert_hidden = hidden_states[i]
             expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size)
             # Quantize for this expert
-            expert_quantized, expert_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+            expert_quantized, expert_scale = quantize_fp8_per_row(
                 expert_hidden_reshaped, num_tokens, self.input_scale_ub
             )
             sharded_expert_dim = self.gate_up_proj.shape[-1] // 2
             gate_up_proj_scale_float32 = self.gate_up_proj_scale.to(torch.float32)
+            if _is_torch_xpu_available:
+                gate = torch._scaled_mm(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous().t(),
+                    scale_a=expert_scale.unsqueeze(-1),
+                    scale_b=gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous().t(),
+                    out_dtype=hidden_states.dtype,
+                )
+                up = torch._scaled_mm(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous().t(),
+                    scale_a=expert_scale.unsqueeze(-1),
+                    scale_b=gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous().t(),
+                    out_dtype=hidden_states.dtype,
+                )
+            else:
+                gate = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
+                    expert_scale,
+                    gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
+                    use_fast_accum=True,
+                )
 
-            gate = torch.ops.fbgemm.f8f8bf16_rowwise(
-                expert_quantized,
-                self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
-                expert_scale,
-                gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
-                use_fast_accum=True,
-            )
-
-            up = torch.ops.fbgemm.f8f8bf16_rowwise(
-                expert_quantized,
-                self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
-                expert_scale,
-                gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
-                use_fast_accum=True,
-            )
+                up = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
+                    expert_scale,
+                    gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
+                    use_fast_accum=True,
+                )
 
             activated = up * self.act_fn(gate)
 
-            activated_quantized, activated_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-                activated, num_tokens, self.input_scale_ub
-            )
+            activated_quantized, activated_scale = quantize_fp8_per_row(activated, num_tokens, self.input_scale_ub)
 
             down_proj_scale_float32 = self.down_proj_scale.to(torch.float32)
-            expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
-                activated_quantized,
-                self.down_proj[i].transpose(0, 1).contiguous(),
-                activated_scale,
-                down_proj_scale_float32[i].view(-1, 1).contiguous(),
-                use_fast_accum=True,
-            )
+            if _is_torch_xpu_available:
+                expert_output = torch._scaled_mm(
+                    activated_quantized,
+                    self.down_proj[i].transpose(0, 1).contiguous(),
+                    scale_a=activated_scale.unsqueeze(-1),
+                    scale_b=down_proj_scale_float32[i].view(-1, 1).contiguous().t(),
+                    out_dtype=hidden_states.dtype,
+                )
+            else:
+                expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    activated_quantized,
+                    self.down_proj[i].transpose(0, 1).contiguous(),
+                    activated_scale,
+                    down_proj_scale_float32[i].view(-1, 1).contiguous(),
+                    use_fast_accum=True,
+                )
 
             next_states[i] = expert_output
         next_states = next_states.to(hidden_states.device)
         return next_states.view(-1, self.hidden_size)
 
 
-def _replace_with_fbgemm_fp8_linear(
-    model,
-    modules_to_not_convert=None,
-    current_key_name=None,
-    quantization_config=None,
-    has_been_replaced=False,
-    pre_quantized=False,
-    config=None,
-    tp_plan=None,
-):
-    """
-    Private method that wraps the recursion for module replacement.
-
-    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
-    """
-
-    import re
-
-    if current_key_name is None:
-        current_key_name = []
-
-    for name, module in model.named_children():
-        current_key_name.append(name)
-
-        if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            current_key_name_str = ".".join(current_key_name)
-            if not any(
-                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
-            ):
-                with init_empty_weights(include_buffers=True):
-                    in_features = module.in_features
-                    out_features = module.out_features
-                    model._modules[name] = FbgemmFp8Linear(
-                        in_features,
-                        out_features,
-                        module.bias is not None,
-                    )
-                    has_been_replaced = True
-
-                # Force requires grad to False to avoid unexpected errors
-                model._modules[name].requires_grad_(False)
-                # set non persistent buffer outside of init_empty_weights
-                model._modules[name].input_scale_ub = torch.tensor(
-                    [quantization_config.activation_scale_ub],
-                    dtype=torch.float,
-                )
-        if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert:
-            current_key_name_str = ".".join(current_key_name)
-            if not any(
-                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
-            ):
-                with init_empty_weights(include_buffers=True):
-                    tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
-                    model._modules[name] = FbgemmFp8Llama4TextExperts(
-                        config.text_config,
-                    )
-                    model._modules[name].input_scale_ub = torch.tensor(
-                        [quantization_config.activation_scale_ub], dtype=torch.float
-                    )
+@lru_cache(maxsize=1)
+def get_quantize_fp8_per_row():
+    if _is_torch_xpu_available:
+        from kernels import get_kernel
 
-        if len(list(module.children())) > 0:
-            _, has_been_replaced = _replace_with_fbgemm_fp8_linear(
-                module,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
-                has_been_replaced=has_been_replaced,
-                pre_quantized=pre_quantized,
-                config=config,
-                tp_plan=tp_plan,
-            )
-        # Remove the last key for recursion
-        current_key_name.pop(-1)
-    return model, has_been_replaced
+        return get_kernel("kernels-community/fp8-fbgemm").quantize_fp8_per_row
+    return torch.ops.fbgemm.quantize_fp8_per_row
 
 
 def replace_with_fbgemm_fp8_linear(
-    model,
-    modules_to_not_convert=None,
-    current_key_name=None,
-    quantization_config=None,
-    pre_quantized=False,
-    config=None,
-    tp_plan=None,
+    model, modules_to_not_convert: list[str] | None = None, quantization_config=None, pre_quantized=False, tp_plan=None
 ):
     """
     A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules.
     This will enable running your models using high performance fp8 kernel from FBGEMM library.
 
-    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
-    be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
-    CPU/GPU memory is required to run this function. Each weight will be quantized along the channel.
-
     Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
-        modules_to_not_convert (`list[`str`]`, *optional*, defaults to `["lm_head"]`):
-            Names of the modules to not convert in `FP8Linear`. In practice we keep the `lm_head` in full precision
-            for numerical stability reasons.
-        current_key_name (`list[`str`]`, *optional*):
-            An array to track the current key of the recursion. This is used to check whether the current key (part of
-            it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
-            `disk`).
+        modules_to_not_convert (`list[`str`]`, *optional*, defaults to `None`):
+            Names of the modules to not convert. In practice we keep the `lm_head` in full precision for numerical stability reasons.
+        quantization_config (`FbgemmFp8Config`):
+            The quantization config object that contains the quantization parameters.
+        pre_quantized (`book`, defaults to `False`):
+            Whether the model is pre-quantized or not
    """
+    global quantize_fp8_per_row
+    quantize_fp8_per_row = get_quantize_fp8_per_row()
+
+    has_been_replaced = False
+    module_kwargs = {} if pre_quantized else {"dtype": None}
+
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+
+        new_module = None
+        with init_empty_weights(include_buffers=True):
+            if module.__class__.__name__ == "Llama4TextExperts":
+                # TODO: make sure tp works later
+                # if tp_plan is not None:
+                #     tp_key = re.sub(r"\d+", "*", f"{module_name}.down_proj_scale")
+                #     tp_plan[tp_key] = None
+                text_config = getattr(model.config, "text_config", model.config)
+                new_module = FbgemmFp8Llama4TextExperts(text_config or model.config)
+            elif isinstance(module, nn.Linear):
+                new_module = FbgemmFp8Linear(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    **module_kwargs,
+                )
+                new_module.requires_grad_(False)
+
+        if new_module is None:
+            continue
+
+        if hasattr(new_module, "input_scale_ub"):
+            new_module.input_scale_ub = torch.tensor(
+                [quantization_config.activation_scale_ub],
+                dtype=torch.float,
+            )
+
+        model.set_submodule(module_name, new_module)
+        has_been_replaced = True
 
-    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
-
-    if quantization_config.modules_to_not_convert is not None:
-        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
-    modules_to_not_convert = list(set(modules_to_not_convert))
-    model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
-        model,
-        modules_to_not_convert,
-        current_key_name,
-        quantization_config,
-        pre_quantized=pre_quantized,
-        config=config,
-        tp_plan=tp_plan,
-    )
     if not has_been_replaced:
         logger.warning(
             "You are loading your model using FP8 quantization but no linear modules were found in your model."