transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/quantizers/quantizer_hqq.py
@@ -12,7 +12,6 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from collections import defaultdict
  from typing import TYPE_CHECKING

  from ..integrations import prepare_for_hqq_linear
@@ -49,10 +48,7 @@ class HqqHfQuantizer(HfQuantizer):
  nn.Linear modules are first tagged with quant_config in _process_model_before_weight_loading().
  """

- use_keep_in_fp32_modules = False
- requires_parameters_quantization = True
  requires_calibration = False
- required_packages = ["hqq"]

  def __init__(self, quantization_config, **kwargs):
  if not is_hqq_available():
@@ -83,73 +79,67 @@ class HqqHfQuantizer(HfQuantizer):
  else:
  self.using_multi_gpu = len(set(device_map.values())) > 1

- def update_missing_keys(
- self, model: "PreTrainedModel", missing_keys: list[str], prefix: str, **kwargs
- ) -> list[str]:
- if self.pre_quantized:
- return [key for key in missing_keys if ("weight" not in key)]
- else:
- return missing_keys
-
- # Adds missing keys for HQQLinear modules that are loaded but the model with initialized with torch.nn.Linear
- def update_expected_keys(
- self, model: "PreTrainedModel", expected_keys: list[str], loaded_keys: list[str]
- ) -> list[str]:
- if not self.pre_quantized:
- return expected_keys
-
- # Collects all quantizable (linear) layers
- def _find_hqq_quantizable_layers(model, layers):
- for name, module in model.named_children():
- if isinstance(module, (torch.nn.Linear)):
- layers.add(module.name)
- _find_hqq_quantizable_layers(module, layers)
-
- new_keys = set(expected_keys)
-
- # Name modules
- for name, module in model.named_modules():
- module.name = name
-
- # valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
- _valid_modules = set()
- _find_hqq_quantizable_layers(model, _valid_modules)
-
- # Remove skipped modules
- _skipped_modules = set()
- for _module in _valid_modules:
- for _skip_module in model.config.quantization_config["skip_modules"]:
- if _skip_module in _module:
- _skipped_modules.add(_module)
- _valid_modules -= _skipped_modules
-
- # Append new expected layers based on _ref_keys
- _ref_keys = HQQLinear(
- linear_layer=None,
- quant_config=None,
- compute_dtype=torch.float16,
- device="cpu",
- del_orig=False,
- ).state_dict_keys() - {"bias"}
-
- # Clean-up
- _rm_keys = set()
- for key in new_keys:
- if any(_module in key for _module in _valid_modules):
- _rm_keys.add(key)
- new_keys -= _rm_keys
- # At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear
-
- # Re-populate Linear/HQQLinear
- for _module in _valid_modules:
- if _module + ".weight" in loaded_keys:
- new_keys.add(_module + ".weight")
- else:
- new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
- if _module + ".bias" in loaded_keys:
- new_keys.add(_module + ".bias")
-
- return list(new_keys)
+ # TODO: to remove
+ # Kept here in case we see some interest in adding support for it
+ # # Adds missing keys for HQQLinear modules that are loaded but the model with initialized with torch.nn.Linear
+ # def update_expected_keys(
+ # self, model: "PreTrainedModel", expected_keys: list[str], loaded_keys: list[str]
+ # ) -> list[str]:
+ # if not self.pre_quantized:
+ # return expected_keys
+
+ # # Collects all quantizable (linear) layers
+ # def _find_hqq_quantizable_layers(model, layers):
+ # for name, module in model.named_children():
+ # if isinstance(module, (torch.nn.Linear)):
+ # layers.add(module.name)
+ # _find_hqq_quantizable_layers(module, layers)
+
+ # new_keys = set(expected_keys)
+
+ # # Name modules
+ # for name, module in model.named_modules():
+ # module.name = name
+
+ # # valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
+ # _valid_modules = set()
+ # _find_hqq_quantizable_layers(model, _valid_modules)
+
+ # # Remove skipped modules
+ # _skipped_modules = set()
+ # for _module in _valid_modules:
+ # for _skip_module in model.config.quantization_config["skip_modules"]:
+ # if _skip_module in _module:
+ # _skipped_modules.add(_module)
+ # _valid_modules -= _skipped_modules
+
+ # # Append new expected layers based on _ref_keys
+ # _ref_keys = HQQLinear(
+ # linear_layer=None,
+ # quant_config=None,
+ # compute_dtype=torch.float16,
+ # device="cpu",
+ # del_orig=False,
+ # ).state_dict_keys() - {"bias"}
+
+ # # Clean-up
+ # _rm_keys = set()
+ # for key in new_keys:
+ # if any(_module in key for _module in _valid_modules):
+ # _rm_keys.add(key)
+ # new_keys -= _rm_keys
+ # # At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear
+
+ # # Re-populate Linear/HQQLinear
+ # for _module in _valid_modules:
+ # if _module + ".weight" in loaded_keys:
+ # new_keys.add(_module + ".weight")
+ # else:
+ # new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
+ # if _module + ".bias" in loaded_keys:
+ # new_keys.add(_module + ".bias")
+
+ # return list(new_keys)

  def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  module, _ = get_module_from_name(model, param_name)
@@ -157,87 +147,88 @@ class HqqHfQuantizer(HfQuantizer):
  # `create_quantized_param`, even when `self.is_quantized == True`
  return isinstance(module, torch.nn.Linear)

- def create_quantized_param(
- self,
- model: "PreTrainedModel",
- param_value: "torch.Tensor",
- param_name: str,
- target_device: "torch.device",
- **kwargs,
- ):
- module, tensor_name = get_module_from_name(model, param_name)
- module_name = param_name.rsplit(".", 1)[0]
- parent_module, node = get_module_from_name(model, module_name)
-
- quant_config = model.config.quantization_config["quant_config"]
- skip_modules = model.config.quantization_config["skip_modules"]
-
- # In this case we do not quantize this layer (it's explicitly skipped) -> simply load param
- if any(skip_module in module.name for skip_module in skip_modules):
- module.load_state_dict(
- {tensor_name: param_value.to(device=target_device, dtype=self.dtype)}, strict=False, assign=True
- )
- return
-
- # We need this hack as the model is not pre-prepared as an empty skeleton on meta device
- if self.pre_quantized:
- # Save them for later
- if not hasattr(self, "hqq_params"):
- self.hqq_params = defaultdict(dict)
- self.hqq_params[module_name].update({tensor_name: param_value})
- hqq_params = self.hqq_params[module_name]
-
- # If they are all present and saved, make it a HQQLinear layer! (we cannot do it param after param because
- # hqq does not support it...)
- if all(k in hqq_params for k in self.hqq_keys) and ("bias" in hqq_params or module.bias is None):
- hqq_layer = HQQLinear(
- linear_layer=None,
- quant_config=None,
- compute_dtype=self.dtype,
- device=target_device,
- del_orig=False,
- )
- hqq_layer.load_state_dict(hqq_params)
-
- if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
- hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
- if self.using_multi_gpu:
- hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
-
- setattr(parent_module, node, hqq_layer)
- del self.hqq_params[module_name], module
- return
-
- # Load param in the module (without caring about device or dtype, it will be changed later)
- module.load_state_dict({tensor_name: param_value}, strict=False, assign=True)
-
- # If both the weight and bias have already been loaded, time to quantize!
- module_is_ready = module.weight.device.type != "meta" and (
- module.bias is None or module.bias.device.type != "meta"
- )
-
- if module_is_ready:
- module_tag = ".".join(module.name.split(".")[-2:])
- if "weight_quant_params" in quant_config:
- module_quant_config = quant_config
- elif module_tag in quant_config:
- module_quant_config = quant_config[module_tag]
-
- hqq_layer = HQQLinear(
- module,
- quant_config=module_quant_config,
- compute_dtype=self.dtype,
- device=target_device,
- del_orig=True,
- )
-
- if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
- hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
-
- if self.using_multi_gpu:
- hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
-
- setattr(parent_module, node, hqq_layer)
+ # TODO: to remove
+ # def create_quantized_param(
+ # self,
+ # model: "PreTrainedModel",
+ # param_value: "torch.Tensor",
+ # param_name: str,
+ # target_device: "torch.device",
+ # **kwargs,
+ # ):
+ # module, tensor_name = get_module_from_name(model, param_name)
+ # module_name = param_name.rsplit(".", 1)[0]
+ # parent_module, node = get_module_from_name(model, module_name)
+
+ # quant_config = model.config.quantization_config["quant_config"]
+ # skip_modules = model.config.quantization_config["skip_modules"]
+
+ # # In this case we do not quantize this layer (it's explicitly skipped) -> simply load param
+ # if any(skip_module in module.name for skip_module in skip_modules):
+ # module.load_state_dict(
+ # {tensor_name: param_value.to(device=target_device, dtype=self.dtype)}, strict=False, assign=True
+ # )
+ # return
+
+ # # We need this hack as the model is not pre-prepared as an empty skeleton on meta device
+ # if self.pre_quantized:
+ # # Save them for later
+ # if not hasattr(self, "hqq_params"):
+ # self.hqq_params = defaultdict(dict)
+ # self.hqq_params[module_name].update({tensor_name: param_value})
+ # hqq_params = self.hqq_params[module_name]
+
+ # # If they are all present and saved, make it a HQQLinear layer! (we cannot do it param after param because
+ # # hqq does not support it...)
+ # if all(k in hqq_params for k in self.hqq_keys) and ("bias" in hqq_params or module.bias is None):
+ # hqq_layer = HQQLinear(
+ # linear_layer=None,
+ # quant_config=None,
+ # compute_dtype=self.dtype,
+ # device=target_device,
+ # del_orig=False,
+ # )
+ # hqq_layer.load_state_dict(hqq_params)
+
+ # if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
+ # hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
+ # if self.using_multi_gpu:
+ # hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
+
+ # setattr(parent_module, node, hqq_layer)
+ # del self.hqq_params[module_name], module
+ # return
+
+ # # Load param in the module (without caring about device or dtype, it will be changed later)
+ # module.load_state_dict({tensor_name: param_value}, strict=False, assign=True)
+
+ # # If both the weight and bias have already been loaded, time to quantize!
+ # module_is_ready = module.weight.device.type != "meta" and (
+ # module.bias is None or module.bias.device.type != "meta"
+ # )
+
+ # if module_is_ready:
+ # module_tag = ".".join(module.name.split(".")[-2:])
+ # if "weight_quant_params" in quant_config:
+ # module_quant_config = quant_config
+ # elif module_tag in quant_config:
+ # module_quant_config = quant_config[module_tag]
+
+ # hqq_layer = HQQLinear(
+ # module,
+ # quant_config=module_quant_config,
+ # compute_dtype=self.dtype,
+ # device=target_device,
+ # del_orig=True,
+ # )
+
+ # if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
+ # hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
+
+ # if self.using_multi_gpu:
+ # hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
+
+ # setattr(parent_module, node, hqq_layer)

  def _patch_layer_for_multigpu(self, hqq_layer):
  def forward_with_device(self, x):
@@ -263,7 +254,7 @@ class HqqHfQuantizer(HfQuantizer):
  model.is_hqq_serializable = self.is_serializable()
  return model

- def is_serializable(self, safe_serialization=None):
+ def is_serializable(self):
  return True

  @property
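The hunks above drop the HQQ-specific `update_missing_keys`, `update_expected_keys`, and `create_quantized_param` overrides (the latter two survive only as commented-out code), so pre-quantized HQQ checkpoints now flow through the shared weight-loading path. As far as this diff shows, the user-facing entry point is untouched; the following is a minimal sketch of on-the-fly HQQ quantization, assuming the `HqqConfig` interface from earlier transformers releases is unchanged (the model id is illustrative only):

# Minimal sketch, not taken from this release: quantize a model with HQQ at load time.
# Assumes the HqqConfig API from earlier transformers releases; model id is illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

quant_config = HqqConfig(nbits=4, group_size=64)  # 4-bit weights, groups of 64

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",                 # illustrative checkpoint
    device_map="auto",
    quantization_config=quant_config,    # eligible nn.Linear layers become HQQLinear during loading
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

Loading a checkpoint that was already saved with HQQ weights should be the same `from_pretrained` call without an explicit `quantization_config`, which is the path the removed overrides used to special-case.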
transformers/quantizers/quantizer_mxfp4.py
@@ -43,14 +43,10 @@ class Mxfp4HfQuantizer(HfQuantizer):
  FP4 quantization using fbgemm kernels
  """

- requires_parameters_quantization = True
  requires_calibration = False

- required_packages = ["accelerate"]
-
  def __init__(self, quantization_config, **kwargs):
  super().__init__(quantization_config, **kwargs)
- self.quantization_config = quantization_config
  self.triton_kernels_hub = None

  def _lazy_import_kernels(self):
@@ -74,7 +70,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
  if self.quantization_config.dequantize:
  return

- if not (torch.cuda.is_available() or torch.xpu.is_available()):
+ if not torch.cuda.is_available() and not torch.xpu.is_available():
  if self.pre_quantized:
  logger.warning_once(
  "Using MXFP4 quantized models requires a GPU, we will default to dequantizing the model to bf16"
@@ -131,12 +127,8 @@ class Mxfp4HfQuantizer(HfQuantizer):
  "You have loaded an FP4 model on CPU and have a CUDA/XPU device available, make sure to set "
  "your model on a GPU/XPU device in order to run your model. To remove this warning, pass device_map = 'cuda' or device_map = 'xpu'. "
  )
- elif device_map is not None:
- if (
- not self.pre_quantized
- and isinstance(device_map, dict)
- and ("cpu" in device_map.values() or "disk" in device_map.values())
- ):
+ elif isinstance(device_map, dict):
+ if not self.pre_quantized and ("cpu" in device_map.values() or "disk" in device_map.values()):
  raise ValueError(
  "You are attempting to load an FP4 model with a device_map that contains a CPU or disk device."
  "This is not supported when the model is quantized on the fly. "
@@ -157,159 +149,30 @@ class Mxfp4HfQuantizer(HfQuantizer):

  def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  from ..integrations import Mxfp4GptOssExperts
- from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts

- if self.pre_quantized:
- return False
- # if we are dequantizing, the model doesn't have scales, and blocks only params like gate_up_proj and down_proj so we need to handle this case differently
- if self.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name):
- module, tensor_name = get_module_from_name(model, param_name[: -len("_blocks")])
- else:
- module, tensor_name = get_module_from_name(model, param_name)
- if isinstance(module, Mxfp4GptOssExperts) or (
- isinstance(module, GptOssExperts) and self.quantization_config.dequantize
- ):
+ module, tensor_name = get_module_from_name(model, param_name)
+ if isinstance(module, Mxfp4GptOssExperts):
  if tensor_name in ["down_proj_bias", "gate_up_proj_bias"]:
  return False
  return True
  return False

- def create_quantized_param(
- self,
- model: "PreTrainedModel",
- param_value: "torch.Tensor",
- param_name: str,
- target_device: "torch.device",
- **kwargs,
- ):
- from ..integrations import (
- Mxfp4GptOssExperts,
- dequantize,
- load_and_swizzle_mxfp4,
- quantize_to_mxfp4,
- swizzle_mxfp4,
- )
- from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
-
- if not self.pre_quantized:
- triton_kernels_hub = self._lazy_import_kernels()
- module, _ = get_module_from_name(model, param_name)
- with torch.device(target_device):
- if isinstance(module, Mxfp4GptOssExperts):
- triton_weight_tensor, weight_scale = quantize_to_mxfp4(param_value, triton_kernels_hub)
- PrecisionConfig, FlexCtx, InFlexData = (
- triton_kernels_hub.matmul_ogs.PrecisionConfig,
- triton_kernels_hub.matmul_ogs.FlexCtx,
- triton_kernels_hub.matmul_ogs.InFlexData,
- )
- triton_weight_tensor, weight_scale = swizzle_mxfp4(
- triton_weight_tensor, weight_scale, triton_kernels_hub
- )
-
- proj = "gate_up_proj" if "gate_up_proj" in param_name else "down_proj"
- setattr(module, proj, triton_weight_tensor)
- setattr(
- module,
- f"{proj}_precision_config",
- PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
- )
-
- delattr(module, f"{proj}_blocks")
- delattr(module, f"{proj}_scales")
-
- # The params going here are either gate_up_proj_blocks, or down_proj_blocks, or gate_up_proj_scales, or down_proj_scales
- else:
- # This is when loading a quantized model (blocks and scales exist)
- empty_param = kwargs.get("empty_param")
- casting_dtype = kwargs.get("casting_dtype")
- to_contiguous = kwargs.get("to_contiguous")
- rank = kwargs.get("rank")
- device_mesh = kwargs.get("device_mesh")
- if ("blocks" in param_name or "scales" in param_name) and self.quantization_config.dequantize:
- # blocks and scales have the same length that's why this works for both
- module, _ = get_module_from_name(model, param_name[: -len("_blocks")])
- else:
- module, _ = get_module_from_name(model, param_name)
-
- shard_kwargs = {
- "empty_param": empty_param,
- "casting_dtype": casting_dtype,
- "to_contiguous": to_contiguous,
- "rank": rank,
- "device_mesh": device_mesh,
- "model": model,
- }
-
- if isinstance(module, Mxfp4GptOssExperts) or (
- isinstance(module, GptOssExperts) and self.quantization_config.dequantize
- ):
- if self.quantization_config.dequantize:
- # dq_param_name is the name of the parameter without the blocks or scales suffix, it's used in this case since we don't switch linears
- # so we only have the original param name
- dq_param_name = param_name[: -len("_blocks")]
- dequantize(module, param_name, param_value, target_device, dq_param_name, **shard_kwargs)
- else:
- load_and_swizzle_mxfp4(
- module,
- param_name,
- param_value,
- target_device,
- self._lazy_import_kernels(),
- **shard_kwargs,
- )
-
  def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
- # we are not really dequantizing, we are just removing everything related to quantization here
- if self.quantization_config.dequantize:
- self.remove_quantization_config(model)
  # clean cache due to triton ops
  if torch.cuda.is_available():
  torch.cuda.empty_cache()
  elif torch.xpu.is_available():
  torch.xpu.empty_cache()

- def update_expected_keys(self, model: "PreTrainedModel", expected_keys: list[str], checkpoint_keys: list[str]):
- # Replace expected_keys for experts' gate_up_proj and down_proj with their _blocks and _scales variants
- new_expected_keys = []
- for key in expected_keys:
- if key.endswith(".mlp.experts.gate_up_proj"):
- base = key[: -len("gate_up_proj")]
- new_expected_keys.append(base + "gate_up_proj_blocks")
- new_expected_keys.append(base + "gate_up_proj_scales")
- elif key.endswith(".mlp.experts.down_proj"):
- base = key[: -len("down_proj")]
- new_expected_keys.append(base + "down_proj_blocks")
- new_expected_keys.append(base + "down_proj_scales")
- elif not self.pre_quantized:
- # in this case, we are quantizing the model so we need to update the keys as we changed the layers
- if key.endswith(".mlp.experts.down_proj_blocks"):
- base = key[: -len("down_proj_blocks")]
- new_expected_keys.append(base + "down_proj")
- elif key.endswith(".mlp.experts.gate_up_proj_blocks"):
- base = key[: -len("gate_up_proj_blocks")]
- new_expected_keys.append(base + "gate_up_proj")
- elif key.endswith("scales"):
- # we remove it the scales as the checkpoint don't contain them
- continue
- else:
- new_expected_keys.append(key)
- else:
- new_expected_keys.append(key)
- return new_expected_keys
-
  def _process_model_before_weight_loading(
  self,
  model: "PreTrainedModel",
  keep_in_fp32_modules: list[str] | None = None,
+ use_kernels: bool = False,
  **kwargs,
  ):
  from ..integrations import replace_with_mxfp4_linear

- self.modules_to_not_convert = self.get_modules_to_not_convert(
- model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
- )
-
- use_kernels = kwargs.get("use_kernels", False)
  # if we are using kernels, we can't use the quantized model, since the forward pass is different and needs special handling
  if use_kernels:
  logger.warning_once(
@@ -318,30 +181,13 @@ class Mxfp4HfQuantizer(HfQuantizer):
  )
  self.quantization_config.dequantize = True

- config = model.config
- model = replace_with_mxfp4_linear(
- model,
- modules_to_not_convert=self.modules_to_not_convert,
- quantization_config=self.quantization_config,
- config=config,
+ self.modules_to_not_convert = self.get_modules_to_not_convert(
+ model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
  )

- model.config.quantization_config = self.quantization_config
-
- def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
- from ..integrations import Mxfp4GptOssExperts
-
- not_missing_keys = []
- for name, module in model.named_modules():
- if isinstance(module, Mxfp4GptOssExperts):
- for missing in missing_keys:
- if (
- (name in missing or name in f"{prefix}.{missing}")
- and not missing.endswith(".weight")
- and not missing.endswith(".bias")
- ):
- not_missing_keys.append(missing)
- return [k for k in missing_keys if k not in not_missing_keys]
+ model = replace_with_mxfp4_linear(
+ model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
+ )

  def update_tp_plan(self, config):
  if "GptOssConfig" in config.__class__.__name__:
@@ -382,7 +228,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
  return param_name.replace("down_proj", "down_proj_blocks")
  return param_name

- def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
+ def get_state_dict_and_metadata(self, model):
  from ..integrations import Mxfp4GptOssExperts

  state_dict = model.state_dict()
@@ -421,7 +267,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
  metadata = {}
  return state_dict, metadata

- def is_serializable(self, safe_serialization=None):
+ def is_serializable(self):
  return True

  @property
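The MXFP4 hunks above remove the quantizer-side parameter machinery (`create_quantized_param`, `update_expected_keys`, `update_missing_keys`) while keeping the device checks and the `replace_with_mxfp4_linear` swap in `_process_model_before_weight_loading`. A minimal sketch of the user-facing flow, assuming the `Mxfp4Config` API used for GPT-OSS checkpoints is unchanged (model id illustrative):

# Minimal sketch, not taken from this release: load an MXFP4-quantized GPT-OSS checkpoint.
# Assumes the Mxfp4Config API used for GPT-OSS; model id is illustrative.
from transformers import AutoModelForCausalLM, Mxfp4Config

# Default path: keep the MXFP4 expert weights (expects a CUDA/XPU device).
model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b", device_map="auto")

# Alternative: request dequantization to bf16 at load time, e.g. for CPU-only setups.
dequant_config = Mxfp4Config(dequantize=True)
model_bf16 = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    quantization_config=dequant_config,
    device_map="auto",
)

Per the warning retained in the diff, loading without a CUDA/XPU device falls back to dequantizing to bf16 anyway; `Mxfp4Config(dequantize=True)` simply requests that behavior explicitly.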