transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/integrations/finegrained_fp8.py
@@ -13,10 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import re
-from typing import Optional
-
 from ..core_model_loading import ConversionOps
+from ..quantizers.quantizers_utils import should_convert_module
 from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
 
 
@@ -158,6 +156,79 @@ def _w8a8_block_fp8_matmul(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
+@triton.jit
+def _w8a8_block_fp8_matmul_per_tensor(
+    # Pointers to inputs and output
+    A,
+    B,
+    C,
+    As,
+    Bs,
+    # Shape for matmul
+    M,
+    N,
+    K,
+    # Block size for block-wise quantization
+    group_n,
+    group_k,
+    # Stride for inputs and output
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    """Triton-accelerated function used to perform linear operations (dot
+    product) on input tensors `A` and `B` with per-tensor quantization, and
+    store the result in output tensor `C`.
+    """
+
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    scale_a = tl.load(As)
+    scale_b = tl.load(Bs)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        accumulator += tl.dot(a, b) * scale_a * scale_b
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if C.dtype.element_ty == tl.bfloat16:
+        c = accumulator.to(tl.bfloat16)
+    elif C.dtype.element_ty == tl.float16:
+        c = accumulator.to(tl.float16)
+    else:
+        c = accumulator.to(tl.float32)
+
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
 def w8a8_block_fp8_matmul_triton(
     A: torch.Tensor,
     B: torch.Tensor,
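Note on the new per-tensor kernel: it differs from the block-wise `_w8a8_block_fp8_matmul` only in how scales are applied, loading a single scalar scale per operand (`tl.load(As)`, `tl.load(Bs)`) and folding both into the fp32 accumulator. For intuition, here is a minimal dequantize-then-matmul reference in plain PyTorch; the helper name, toy shapes and scale values are illustrative and not part of the diff, and the real kernel fuses the scaling into the tiled `tl.dot` loop instead of materializing dequantized tensors.

```python
import torch

def w8a8_per_tensor_matmul_reference(a_q, b_q, scale_a, scale_b, out_dtype=torch.bfloat16):
    # Dequantize each operand with its single (per-tensor) scale, then matmul.
    # The Triton kernel computes the same result via accumulator += tl.dot(a, b) * scale_a * scale_b.
    a = a_q.to(torch.float32) * scale_a
    b = b_q.to(torch.float32) * scale_b
    return (a @ b.t()).to(out_dtype)  # A is (M, K), B is (N, K) -> C is (M, N)

a_q = torch.randn(4, 64).to(torch.float8_e4m3fn)  # quantized activations
b_q = torch.randn(8, 64).to(torch.float8_e4m3fn)  # quantized weight
print(w8a8_per_tensor_matmul_reference(a_q, b_q, scale_a=0.05, scale_b=0.02).shape)  # torch.Size([4, 8])
```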
@@ -181,19 +252,31 @@ def w8a8_block_fp8_matmul_triton(
     Returns:
         torch.Tensor: The result of matmul.
     """
-    assert len(block_size) == 2
-    block_n, block_k = block_size[0], block_size[1]
+    if block_size is None:
+        block_n, block_k = 128, 128
+    else:
+        assert len(block_size) == 2
+        block_n, block_k = block_size[0], block_size[1]
+
+    # if we have per-tensor quantization, we use 128x128 block size for tiled matmul multiplication
+    if block_n == B.shape[-2] and block_k == B.shape[-1]:
+        block_n = 128
+        block_k = 128
 
     assert A.shape[-1] == B.shape[-1]
 
-    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
-    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+    if As.numel() != 1:
+        assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
+        assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+
     M = A.numel() // A.shape[-1]
 
-    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
     N, K = B.shape
-    assert triton.cdiv(N, block_n) == Bs.shape[0], f"{N}, {block_n}, {Bs.shape}"
-    assert triton.cdiv(K, block_k) == Bs.shape[1], f"{K}, {block_k}, {Bs.shape}"
+    assert B.ndim == 2 and B.is_contiguous()
+    if Bs.numel() != 1:
+        assert Bs.ndim == 2
+        assert triton.cdiv(N, block_n) == Bs.shape[0], f"{N}, {block_n}, {Bs.shape}"
+        assert triton.cdiv(K, block_k) == Bs.shape[1], f"{K}, {block_k}, {Bs.shape}"
 
     C_shape = A.shape[:-1] + (N,)
     C = A.new_empty(C_shape, dtype=output_dtype)
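The relaxed shape checks above encode the two layouts the function now accepts: block-wise scales carry one entry per tile (hence the `triton.cdiv` asserts), while per-tensor scales are a single element (`numel() == 1`) and skip the checks entirely. A small sketch of the block-wise arithmetic, with a helper name of my own:

```python
import math

def expected_scale_shapes(M, N, K, block_n=128, block_k=128):
    # One activation scale per block_k slice of each row, one weight scale per
    # (block_n x block_k) tile; this is what the triton.cdiv asserts verify.
    As_shape = (M, math.ceil(K / block_k))
    Bs_shape = (math.ceil(N / block_n), math.ceil(K / block_k))
    return As_shape, Bs_shape

print(expected_scale_shapes(M=4, N=512, K=384))  # ((4, 3), (4, 3))
```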
@@ -209,32 +292,56 @@ def w8a8_block_fp8_matmul_triton(
     def grid(META):
         return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
 
-    _w8a8_block_fp8_matmul[grid](
-        A,
-        B,
-        C,
-        As,
-        Bs,
-        M,
-        N,
-        K,
-        block_n,
-        block_k,
-        A.stride(-2),
-        A.stride(-1),
-        B.stride(1),
-        B.stride(0),
-        C.stride(-2),
-        C.stride(-1),
-        As.stride(-2),
-        As.stride(-1),
-        Bs.stride(1),
-        Bs.stride(0),
-        BLOCK_SIZE_M=BLOCK_SIZE_M,
-        BLOCK_SIZE_N=BLOCK_SIZE_N,
-        BLOCK_SIZE_K=BLOCK_SIZE_K,
-        GROUP_SIZE_M=8,
-    )
+    if As.numel() == 1 and Bs.numel() == 1:
+        _w8a8_block_fp8_matmul_per_tensor[grid](
+            A,
+            B,
+            C,
+            As,
+            Bs,
+            M,
+            N,
+            K,
+            block_n,
+            block_k,
+            A.stride(-2),
+            A.stride(-1),
+            B.stride(1),
+            B.stride(0),
+            C.stride(-2),
+            C.stride(-1),
+            BLOCK_SIZE_M=BLOCK_SIZE_M,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+            BLOCK_SIZE_K=BLOCK_SIZE_K,
+            GROUP_SIZE_M=8,
+        )
+    else:
+        _w8a8_block_fp8_matmul[grid](
+            A,
+            B,
+            C,
+            As,
+            Bs,
+            M,
+            N,
+            K,
+            block_n,
+            block_k,
+            A.stride(-2),
+            A.stride(-1),
+            B.stride(1),
+            B.stride(0),
+            C.stride(-2),
+            C.stride(-1),
+            As.stride(-2),
+            As.stride(-1),
+            Bs.stride(1),
+            Bs.stride(0),
+            BLOCK_SIZE_M=BLOCK_SIZE_M,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+            BLOCK_SIZE_K=BLOCK_SIZE_K,
+            GROUP_SIZE_M=8,
+        )
 
     return C
 
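Kernel selection is now driven purely by the scale tensors: scalar scales on both operands mean per-tensor quantization, anything else keeps the original block-wise path. A trivial restatement of that predicate (the function name is mine):

```python
import torch

def uses_per_tensor_kernel(As: torch.Tensor, Bs: torch.Tensor) -> bool:
    # Mirrors the branch above: both scales scalar -> _w8a8_block_fp8_matmul_per_tensor.
    return As.numel() == 1 and Bs.numel() == 1

print(uses_per_tensor_kernel(torch.tensor(0.05), torch.tensor(0.02)))  # True
print(uses_per_tensor_kernel(torch.ones(4, 1), torch.ones(32, 86)))    # False
```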
@@ -307,44 +414,34 @@ def w8a8_block_fp8_matmul_compile(
 
 
 class FP8Linear(nn.Linear):
-    dtype = torch.float8_e4m3fn
-
     def __init__(
         self,
         in_features: int,
         out_features: int,
         bias: bool = False,
-        dtype=None,
+        dtype=torch.float8_e4m3fn,
         block_size: tuple[int, int] | None = None,
-        device=None,
         activation_scheme="dynamic",
     ):
         super().__init__(in_features, out_features)
-        self.in_features = in_features
-        self.out_features = out_features
 
-        if block_size is not None:
-            self.block_size = block_size
-        else:
-            self.block_size = (out_features, in_features)
+        # If block size is None, it means that we are doing per-tensor quantization
+        self.block_size = block_size
+        self.activation_scheme = activation_scheme
 
-        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device))
+        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
 
-        if self.weight.element_size() == 1:
+        if self.block_size is None:
+            self.weight_scale_inv = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
+        else:
            scale_out_features = (out_features + self.block_size[0] - 1) // self.block_size[0]
            scale_in_features = (in_features + self.block_size[1] - 1) // self.block_size[1]
-            if scale_out_features * scale_in_features == 1:
-                self.weight_scale_inv = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
-            else:
-                self.weight_scale_inv = nn.Parameter(
-                    torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
-                )
-        else:
-            self.register_parameter("weight_scale_inv", None)
-        self.activation_scheme = activation_scheme
+            self.weight_scale_inv = nn.Parameter(
+                torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
+            )
 
         if self.activation_scheme == "static":
-            self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
+            self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
 
         if bias:
             self.bias = nn.Parameter(torch.empty(self.out_features))
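`FP8Linear` now treats `block_size=None` as per-tensor quantization: the weight scale collapses to a single fp32 scalar, otherwise one scale is kept per `(block_size[0], block_size[1])` tile of the weight, rounded up on both axes. A standalone sketch of that shape logic (the helper name and example dimensions are mine):

```python
import torch
import torch.nn as nn

def make_weight_scale_inv(out_features: int, in_features: int, block_size):
    if block_size is None:
        # Per-tensor quantization: a single scalar inverse scale.
        return nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
    # Block-wise quantization: one inverse scale per weight tile.
    scale_out = (out_features + block_size[0] - 1) // block_size[0]
    scale_in = (in_features + block_size[1] - 1) // block_size[1]
    return nn.Parameter(torch.empty(scale_out, scale_in, dtype=torch.float32))

print(make_weight_scale_inv(4096, 11008, None).shape)        # torch.Size([])
print(make_weight_scale_inv(4096, 11008, (128, 128)).shape)  # torch.Size([32, 86])
```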
@@ -368,30 +465,28 @@ class FP8Linear(nn.Linear):
         if self.activation_scheme == "dynamic":
             qinput, scale = act_quant(input, self.block_size[1])
         elif self.activation_scheme == "static":
-            scale = self.activation_scale
-            qinput = (input / scale).to(torch.float8_e4m3fn)
+            scale = self.activation_scale.to(torch.float32)
+            qinput = (input / scale).clamp(min=_FP8_MIN, max=_FP8_MAX).to(torch.float8_e4m3fn)
+
         else:
             raise NotImplementedError("Not supported")
-        # TODO: fix this later to use the triton kernel
-        if self.activation_scheme == "static":
-            output = F.linear(qinput.to(torch.bfloat16), weight.to(torch.bfloat16), None) * scale_inv * scale
-            output = output.to(input.dtype)
-        else:
-            output = w8a8_block_fp8_matmul_triton(
-                qinput,
-                weight,
-                scale,
-                scale_inv,
-                self.block_size,
-                output_dtype=input.dtype,
-            )
+
+        output = w8a8_block_fp8_matmul_triton(
+            qinput,
+            weight,
+            scale,
+            scale_inv,
+            self.block_size,
+            output_dtype=input.dtype,
+        )
 
         # Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
         # preceding operations are ready before proceeding
         torch_accelerator_module.synchronize()
         if self.bias is not None:
             output = output + self.bias
-        output = torch.nan_to_num(output, nan=0.0)
+
+        # output = torch.nan_to_num(output, nan=0.0)
         return output.to(dtype=input.dtype)
 
 
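The static activation path now clamps to the fp8 representable range before casting, so out-of-range activations saturate instead of overflowing (which is also why the old `nan_to_num` fallback is commented out). A minimal sketch of that step; I am assuming `_FP8_MIN`/`_FP8_MAX` are the `torch.finfo(torch.float8_e4m3fn)` bounds, which this hunk does not show:

```python
import torch

# Assumed to correspond to the module-level _FP8_MIN / _FP8_MAX constants.
FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def static_act_quant(x: torch.Tensor, activation_scale: torch.Tensor):
    # Static scheme: one pre-computed scale for all activations; clamp before
    # the fp8 cast so extreme values saturate rather than become inf/nan.
    scale = activation_scale.to(torch.float32)
    qinput = (x / scale).clamp(min=FP8_MIN, max=FP8_MAX).to(torch.float8_e4m3fn)
    return qinput, scale

q, s = static_act_quant(torch.randn(2, 8) * 100, torch.tensor(0.05))
print(q.dtype, q.to(torch.float32).abs().max().item())  # torch.float8_e4m3fn, at most 448.0
```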
@@ -400,9 +495,7 @@ def _ceil_div(a, b):
 
 
 class FP8Expert(nn.Module):
-    dtype = torch.float8_e4m3fn
-
-    def __init__(self, config, block_size, device):
+    def __init__(self, config, block_size, dtype=torch.float8_e4m3fn):
         super().__init__()
 
         from ..activations import ACT2FN
@@ -415,34 +508,24 @@ class FP8Expert(nn.Module):
         Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim
         Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim
 
-        self.gate_up_proj = nn.Parameter(
-            torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device)
-        )
-        self.down_proj = nn.Parameter(
-            torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=FP8Expert.dtype, device=device)
-        )
+        self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=dtype))
+        self.down_proj = nn.Parameter(torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=dtype))
 
-        # Create inverse scale tiles only when using 1-byte types (fp8)
-        if self.gate_up_proj.element_size() == 1:
-            bo, bi = self.block_size
+        bo, bi = self.block_size
 
-            # gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
-            gu_scale_o = _ceil_div(Wg_out, bo)
-            gu_scale_i = _ceil_div(Wg_in, bi)
-            self.gate_up_proj_scale_inv = nn.Parameter(
-                torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device)
-            )
+        # gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
+        gu_scale_o = _ceil_div(Wg_out, bo)
+        gu_scale_i = _ceil_div(Wg_in, bi)
+        self.gate_up_proj_scale_inv = nn.Parameter(
+            torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32)
+        )
 
-            # down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
-            dp_scale_o = _ceil_div(Wd_out, bo)
-            dp_scale_i = _ceil_div(Wd_in, bi)
-            self.down_proj_scale_inv = nn.Parameter(
-                torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device)
-            )
-        else:
-            # Match FP8Linear behavior when not using 1-byte weights
-            self.register_parameter("gate_up_proj_scale_inv", None)
-            self.register_parameter("down_proj_scale_inv", None)
+        # down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
+        dp_scale_o = _ceil_div(Wd_out, bo)
+        dp_scale_i = _ceil_div(Wd_in, bi)
+        self.down_proj_scale_inv = nn.Parameter(
+            torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32)
+        )
 
         # (Optional) bias per projection — many MoEs omit bias; keep None to match your FP8Linear default
         self.register_parameter("gate_up_bias", None)
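`FP8Expert` applies the same ceil-divided tiling as `FP8Linear`, with an extra leading expert dimension on both the fp8 weights and their fp32 scale grids. A quick arithmetic check with illustrative MoE dimensions (not taken from any real config):

```python
def _ceil_div(a, b):
    return (a + b - 1) // b

# Illustrative sizes: hidden 2048, intermediate 5632, 8 experts, 128x128 weight blocks.
hidden, inter, num_experts, (bo, bi) = 2048, 5632, 8, (128, 128)
gate_up_scale_shape = (num_experts, _ceil_div(2 * inter, bo), _ceil_div(hidden, bi))
down_scale_shape = (num_experts, _ceil_div(hidden, bo), _ceil_div(inter, bi))
print(gate_up_scale_shape, down_scale_shape)  # (8, 88, 16) (8, 16, 44)
```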
@@ -508,90 +591,56 @@ class FP8Expert(nn.Module):
         return output.to(dtype=input.dtype)
 
 
-# TODO: we do need this.... but not recursive...
-def _replace_with_fp8_linear(
-    model,
-    tp_plan=None,
-    modules_to_not_convert=None,
-    current_key_name=None,
-    quantization_config=None,
-    has_been_replaced=False,
-):
-    iterator = list(model.named_parameters()).copy()
-    for name, empty_tensor in iterator:
-        current_key_name = name
-        name = name.rsplit(".", 1)[0] if "." in name else name
-        module = model.get_submodule(name)
-
-        current_key_name_str = re.sub(r"\d+", "*", current_key_name)
-        if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
-            with init_empty_weights():
-                if (
-                    "gate_up_proj" in current_key_name
-                    or "down_proj" in current_key_name
-                    and "experts" in current_key_name
-                ):  # Experts!
-                    in_features = empty_tensor.size(-2)
-                    out_features = empty_tensor.size(-1)
-                    model.set_submodule(
-                        name,
-                        FP8Expert(
-                            config=model.config,
-                            block_size=quantization_config.weight_block_size,
-                            device=empty_tensor.device,
-                        ),
-                    )
-
-                elif isinstance(module, nn.Linear):
-                    in_features = module.in_features
-                    out_features = module.out_features
-                    model.set_submodule(
-                        name,
-                        FP8Linear(
-                            in_features=in_features,
-                            out_features=out_features,
-                            bias=module.bias is not None,
-                            device=module.weight.device,
-                            dtype=module.weight.dtype,
-                            activation_scheme=quantization_config.activation_scheme,
-                            block_size=quantization_config.weight_block_size,
-                        ),
-                    )
-                    has_been_replaced = True
-                    # when changing a layer the TP PLAN for that layer should be updated. TODO
-
-    return model, has_been_replaced
-
-
 def replace_with_fp8_linear(
-    model,
-    modules_to_not_convert=None,
-    quantization_config=None,
+    model, modules_to_not_convert: list[str] | None = None, quantization_config=None, pre_quantized=False
 ):
-    """Helper function to replace model layers with FP8 versions."""
+    """
+    A helper function to replace all `torch.nn.Linear` modules by `FP8Linear` modules.
+
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        modules_to_not_convert (`list[`str`]`, *optional*, defaults to `None`):
+            Names of the modules to not convert. In practice we keep the `lm_head` in full precision for numerical stability reasons.
+        quantization_config (`FbgemmFp8Config`):
+            The quantization config object that contains the quantization parameters.
+        pre_quantized (`book`, defaults to `False`):
+            Whether the model is pre-quantized or not
+    """
+
     if quantization_config.dequantize:
         return model
 
-    if modules_to_not_convert is None:
-        modules_to_not_convert = []
-    modules_to_not_convert += ["lm_head"]
-
-    if quantization_config.modules_to_not_convert is not None:
-        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
-    modules_to_not_convert = list(set(modules_to_not_convert))
-    model, has_been_replaced = _replace_with_fp8_linear(
-        model,
-        tp_plan=model._tp_plan,
-        modules_to_not_convert=modules_to_not_convert,
-        quantization_config=quantization_config,
-    )
+    has_been_replaced = False
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+        # we need this to correctly materialize the weights during quantization
+        module_kwargs = {} if pre_quantized else {"dtype": None}
+        new_module = None
+        with init_empty_weights():
+            if module_name.endswith(".experts"):
+                new_module = FP8Expert(
+                    config=model.config, block_size=quantization_config.weight_block_size, **module_kwargs
+                )
+            elif isinstance(module, nn.Linear):
+                new_module = FP8Linear(
+                    in_features=module.in_features,
+                    out_features=module.out_features,
+                    bias=module.bias is not None,
+                    activation_scheme=quantization_config.activation_scheme,
+                    block_size=quantization_config.weight_block_size,
+                    **module_kwargs,
+                )
+        if new_module is not None:
+            model.set_submodule(module_name, new_module)
+            has_been_replaced = True
 
     if not has_been_replaced:
         logger.warning(
             "You are loading your model using fp8 but no linear modules were found in your model."
             " Please double check your model architecture."
        )
-
     return model
 
 
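The rewritten `replace_with_fp8_linear` does a single, non-recursive pass over `named_modules()`, asks `should_convert_module` whether the dotted name is excluded, and swaps the module in place with `set_submodule`. A self-contained sketch of the same pattern on a toy model; it uses a manual parent lookup instead of `nn.Module.set_submodule` so it runs on older PyTorch, and all helper names are mine:

```python
import torch.nn as nn

def swap_linears(model: nn.Module, should_convert, make_replacement):
    # One flat pass over named_modules(), as in the new helper: the full dotted
    # name decides whether a module is converted, then it is replaced in place.
    replaced = False
    for name, module in list(model.named_modules()):
        if not name or not should_convert(name) or not isinstance(module, nn.Linear):
            continue
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, make_replacement(module))
        replaced = True
    return replaced

toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
swap_linears(
    toy,
    should_convert=lambda name: name != "2",  # e.g. keep the final projection untouched
    make_replacement=lambda m: nn.Linear(m.in_features, m.out_features, bias=m.bias is not None),
)
print(toy)
```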
@@ -606,7 +655,7 @@ class Fp8Quantize(ConversionOps):
     def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]:
         # Unpack single key/value (value may be wrapped in a list)
         target_keys, value = tuple(input_dict.items())[0]
-        value = value[0] if isinstance(value, list) else value
+        value = value[0]
 
         # Resolve block size (support dict-like or attr-like quant_config)
         block_size = None
@@ -681,36 +730,30 @@ class Fp8Dequantize(ConversionOps):
     def convert(
         self,
         input_dict: dict[str, torch.Tensor],
-        model: Optional[torch.nn.Module] = None,
         full_layer_name: str | None = None,
-        missing_keys=None,
         **kwargs,
     ) -> dict[str, torch.Tensor]:
-        if len(input_dict) != 2:
-            # in case of no scales, the weights are not quantized, so we return the weights as is
-            return {
-                full_layer_name: input_dict["weight$"][0]
-                if isinstance(input_dict["weight$"], list)
-                else input_dict["weight$"]
-            }
-        quantized = input_dict["weight$"][0] if isinstance(input_dict["weight$"], list) else input_dict["weight$"]
-        scales = (
-            input_dict["weight_scale_inv"][0]
-            if isinstance(input_dict["weight_scale_inv"], list)
-            else input_dict["weight_scale_inv"]
-        )
+        if len(input_dict) < 2:
+            # case where we only got weights, need to check for "weight$"
+            return {full_layer_name: input_dict["weight$"]}
+
+        quantized = input_dict["weight$"][0]
+        scales = input_dict["weight_scale_inv"][0]
 
         rows, cols = quantized.shape[-2:]
         block_size = self.hf_quantizer.quantization_config.weight_block_size
+        if block_size is None:
+            block_size = (quantized.shape[-2], quantized.shape[-1])
 
         block_m, block_n = block_size
+
         if rows % block_m != 0 or cols % block_n != 0:
             raise ValueError(
                 f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
             )
-
+        quantized = quantized.to(scales.dtype)
         reshaped = quantized.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
-        expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n)
+        expanded_scales = scales.reshape(-1, rows // block_m, cols // block_n)
         expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
         dequantized = reshaped * expanded_scales
 
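`Fp8Dequantize.convert` relies on a reshape/broadcast trick: view the quantized matrix as a grid of `(block_m, block_n)` tiles and multiply each tile by its scale, falling back to one whole-matrix block (per-tensor) when `weight_block_size` is `None`. The same trick in isolation, with names and shapes of my own choosing:

```python
import torch

def blockwise_dequant(quantized: torch.Tensor, scales: torch.Tensor, block_size):
    rows, cols = quantized.shape[-2:]
    block_m, block_n = block_size
    # View as (batch, row tiles, block_m, col tiles, block_n) and broadcast one
    # scale over every (block_m x block_n) tile, mirroring the reshape above.
    tiles = quantized.to(scales.dtype).reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
    expanded = scales.reshape(-1, rows // block_m, cols // block_n).unsqueeze(-1).unsqueeze(2)
    return (tiles * expanded).reshape(quantized.shape)

w_q = torch.randn(256, 512).to(torch.float8_e4m3fn)
w_scales = torch.rand(2, 4, dtype=torch.float32)  # one scale per 128x128 tile
print(blockwise_dequant(w_q, w_scales, (128, 128)).shape)  # torch.Size([256, 512])
```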
transformers/integrations/fp_quant.py
@@ -13,6 +13,10 @@
 # limitations under the License.
 "FP-Quant integration file"
 
+from typing import Optional
+
+import torch
+
 from ..utils import (
     is_fp_quant_available,
 )
@@ -24,6 +28,94 @@ if is_fp_quant_available():
 
     from transformers.utils.quantization_config import FPQuantConfig
 
+    from ..core_model_loading import ConversionOps
+    from ..quantizers.quantizers_utils import get_module_from_name
+
+
+    class FpQuantQuantize(ConversionOps):
+        def __init__(self, hf_quantizer):
+            self.hf_quantizer = hf_quantizer
+
+        def convert(
+            self,
+            input_dict: torch.Tensor,
+            model: Optional[torch.nn.Module] = None,
+            missing_keys: Optional[list[str]] = None,
+            **kwargs,
+        ) -> dict[str, torch.Tensor]:
+            target_key, value = tuple(input_dict.items())[0]
+            value = value[0]
+            # Loading master weights or an unquantized checkpoint
+            weight = torch.nn.Parameter(value)
+            module, _ = get_module_from_name(model, target_key)
+            module.weight = weight
+
+            # Let pre-forward handle the quantization and set None where necessary
+            # This operation will quantize the weights internally
+            with torch.cuda.device(value.device):
+                module.pre_forward()
+
+            prefix_target_key = target_key.rsplit(".", 1)[0]
+
+            # keys are set inside the module.pre_forward() method, we don't need remove them from the missing keys list
+            missing_keys.discard(target_key)
+            missing_keys.discard(f"{prefix_target_key}.backward_hadamard_matrix")
+            missing_keys.discard(f"{prefix_target_key}.forward_hadamard_matrix")
+            missing_keys.discard(f"{prefix_target_key}.act_global_scale")
+            missing_keys.discard(f"{prefix_target_key}.weight_global_scale")
+            missing_keys.discard(f"{prefix_target_key}.qweight")
+            missing_keys.discard(f"{prefix_target_key}.scales")
+            missing_keys.discard(f"{prefix_target_key}.dqweight")
+            return {}
+
+
+    class FpQuantDeserialize(ConversionOps):
+        def __init__(self, hf_quantizer):
+            self.hf_quantizer = hf_quantizer
+
+        def convert(
+            self,
+            input_dict: torch.Tensor,
+            model: Optional[torch.nn.Module] = None,
+            full_layer_name: str | None = None,
+            missing_keys: Optional[list[str]] = None,
+            **kwargs,
+        ) -> dict[str, torch.Tensor]:
+            target_key, value = tuple(input_dict.items())[0]
+            value = value[0] if isinstance(value, list) else value
+            module, _ = get_module_from_name(model, target_key)
+            # The module holds either:
+            # * `weight` when `store_master_weights=True`
+            # * `qweight` and `scales` when `store_master_weights=False` and `pseudoquantization=False`
+            # * `dqweight` when `store_master_weights=False` and `pseudoquantization=True`
+            if target_key == ".qweight":
+                # Loading a real quantized checkpoint without master weights
+                qweight = torch.nn.Parameter(
+                    value,
+                    requires_grad=False,
+                )
+
+                return {
+                    ".qweight": qweight,
+                    # the way the FPQuantLinear module is designed, these parameters are expected in the model
+                    # even though they are not used so we need to set them to zeros
+                    ".weight": torch.nn.Parameter(torch.zeros(0)),
+                    ".dqweight": torch.nn.Parameter(torch.zeros(0)),
+                }
+
+            if target_key == ".dqweight":
+                # Loading a pseudo-quantized checkpoint without master weights
+                dqweight = torch.nn.Parameter(value)
+
+                return {
+                    ".dqweight": dqweight,
+                    # the way the FPQuantLinear module ips designed, these parameters are expected in the model
+                    # even though they are not used so we need to set them to zeros
+                    ".weight": torch.nn.Parameter(torch.zeros(0)),
+                    ".qweight": torch.nn.Parameter(torch.zeros(0)),
+                    ".scales": torch.nn.Parameter(torch.zeros(0)),
+                }
+
 
 def adapt_fp_quant_config(config: FPQuantConfig):
     if config.forward_dtype == "mxfp4":
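Both FP-Quant ops also manage the loader's missing-keys tracker by hand: once `pre_forward()` has materialized the quantized sub-parameters, their dotted names are discarded so they are not reported as missing. A standalone sketch of that bookkeeping, assuming the tracker behaves like a set (the ops call `.discard` on it); the key names below are illustrative:

```python
def discard_derived_keys(missing_keys: set, target_key: str, derived_suffixes):
    # Drop the original weight key plus every sub-parameter created on the fly.
    prefix = target_key.rsplit(".", 1)[0]
    missing_keys.discard(target_key)
    for suffix in derived_suffixes:
        missing_keys.discard(f"{prefix}.{suffix}")
    return missing_keys

missing = {"model.layers.0.mlp.up_proj.weight", "model.layers.0.mlp.up_proj.qweight"}
print(discard_derived_keys(missing, "model.layers.0.mlp.up_proj.weight", ["qweight", "scales"]))
# set(): both tracked names are produced by on-the-fly quantization, so neither is missing
```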