transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -40,23 +40,10 @@ class QuantoHfQuantizer(HfQuantizer):
     Quantizer for the quanto library
     """
 
-    required_packages = ["quanto", "accelerate"]
-    requires_parameters_quantization = True
     requires_calibration = False
 
     def __init__(self, quantization_config: QuantoConfig, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        self.post_init()
-
-    def post_init(self):
-        r"""
-        Safety checker
-        """
-        if self.quantization_config.activations is not None and not self.pre_quantized:
-            raise ValueError(
-                "We don't support quantizing the activations with transformers library."
-                "Use quanto library for more complex use cases such as activations quantization, calibration and quantization aware training."
-            )
 
     def validate_environment(self, *args, **kwargs):
         if not is_optimum_quanto_available():
@@ -67,42 +54,22 @@ class QuantoHfQuantizer(HfQuantizer):
             raise ImportError(
                 "Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)"
             )
-
-    def update_device_map(self, device_map):
-        if device_map is None:
-            device_map = {"": "cpu"}
-            logger.info(
-                "The device_map was not initialized. "
-                "Setting device_map to {'':'cpu'}. "
-                "If you want to use the model for inference, please set device_map ='auto'"
+        device_map = kwargs.get("device_map")
+        if isinstance(device_map, dict):
+            if len(device_map) > 1 and "cpu" in device_map.values() or "disk" in device_map.values():
+                raise ValueError(
+                    "You are attempting to load an model with a device_map that contains a CPU or disk device."
+                    "This is not supported with quanto when the model is quantized on the fly. "
+                    "Please remove the CPU or disk device from the device_map."
+                )
+        if self.quantization_config.activations is not None:
+            raise ValueError(
+                "We don't support quantizing the activations with transformers library."
+                "Use quanto library for more complex use cases such as activations quantization, calibration and quantization aware training."
             )
-        return device_map
-
-    def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
-        if dtype is None:
-            logger.info("You did not specify `dtype` in `from_pretrained`. Setting it to `torch.float32`.")
-            dtype = torch.float32
-        return dtype
-
-    def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
-        if is_optimum_quanto_available():
-            from optimum.quanto import QModuleMixin
-
-            not_missing_keys = []
-            for name, module in model.named_modules():
-                if isinstance(module, QModuleMixin):
-                    for missing in missing_keys:
-                        if (
-                            (name in missing or name in f"{prefix}.{missing}")
-                            and not missing.endswith(".weight")
-                            and not missing.endswith(".bias")
-                        ):
-                            not_missing_keys.append(missing)
-            return [k for k in missing_keys if k not in not_missing_keys]
 
     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
-        if is_optimum_quanto_available():
-            from optimum.quanto import QModuleMixin
+        from optimum.quanto import QModuleMixin
 
         module, tensor_name = get_module_from_name(model, param_name)
         # We only quantize the weights and the bias is not quantized.
@@ -116,21 +83,6 @@ class QuantoHfQuantizer(HfQuantizer):
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
         return max_memory
 
-    def create_quantized_param(
-        self,
-        model: "PreTrainedModel",
-        param_value: "torch.Tensor",
-        param_name: str,
-        target_device: "torch.device",
-        **kwargs,
-    ):
-        from ..modeling_utils import _load_parameter_into_model
-
-        _load_parameter_into_model(model, param_name, param_value.to(target_device))
-        module, _ = get_module_from_name(model, param_name)
-        module.freeze()
-        module.weight.requires_grad = False
-
     def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
         from accelerate.utils import CustomDtype
 
@@ -152,14 +104,18 @@ class QuantoHfQuantizer(HfQuantizer):
             model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
         )
 
-        model, _ = replace_with_quanto_layers(
+        model = replace_with_quanto_layers(
            model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
        )
-        model.config.quantization_config = self.quantization_config
 
     @property
     def is_trainable(self) -> bool:
         return True
 
-    def is_serializable(self, safe_serialization=None):
+    def is_serializable(self):
         return False
+
+    def get_quantize_ops(self):
+        from ..integrations.quanto import QuantoQuantize
+
+        return QuantoQuantize(self)
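
The new `validate_environment` above folds the old `update_device_map` and `post_init` checks into a single guard: on-the-fly quanto quantization refuses device maps that offload to CPU or disk, and activation quantization is still rejected. A minimal sketch (mine, not part of the diff) of that guard applied to a few hypothetical device maps:

# Sketch only: mirrors the condition in the new validate_environment shown above.
# Note Python's precedence groups it as (len(...) > 1 and "cpu" in values) or ("disk" in values),
# so a map containing "disk" is rejected even when it has a single entry.
def quanto_rejects(device_map: dict) -> bool:
    return len(device_map) > 1 and "cpu" in device_map.values() or "disk" in device_map.values()

print(quanto_rejects({"": 0}))                                  # False: everything on one GPU
print(quanto_rejects({"model.layers": 0, "lm_head": "cpu"}))    # True: CPU offload during on-the-fly quantization
print(quanto_rejects({"": "disk"}))                             # True: disk offload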
@@ -45,12 +45,6 @@ class QuarkHfQuantizer(HfQuantizer):
     """
 
     requires_calibration = True  # On-the-fly quantization with quark is not supported for now.
-    required_packages = ["quark"]
-
-    # Checkpoints are expected to be already quantized when loading a quark model. However, as some keys from
-    # the checkpoint might mismatch the model parameters keys, we use the `create_quantized_param` method
-    # to load the checkpoints, remapping the keys.
-    requires_parameters_quantization = True
 
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)
@@ -78,19 +72,44 @@ class QuarkHfQuantizer(HfQuantizer):
     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         return True
 
-    def create_quantized_param(self, model, param, param_name, param_device, **kwargs):
-        from ..modeling_utils import _load_parameter_into_model
-
-        postfix = param_name.split(".")[-1]
-
-        if postfix in CHECKPOINT_KEYS:
-            param_name = param_name.replace(postfix, CHECKPOINT_KEYS[postfix])
-
-        _load_parameter_into_model(model, param_name, param.to(param_device))
-
-    def is_serializable(self, safe_serialization=None):
+    def is_serializable(self):
         return False
 
     @property
     def is_trainable(self):
         return False
+
+    def get_weight_conversions(self):
+        from ..core_model_loading import WeightConverter
+        from ..integrations.quark import QuarkDeserialize
+        # In Quark, quantization is managed through a QParamsLinear module, which holds
+        # separate quantizers for the weights, inputs, and biases (e.g. weight_quantizer
+        # input_quantizer, bias_quantizer, etc.).
+        #
+        # When you call `module.state_dict()`, Quark automatically renames the quantizer
+        # parameters — for example, `input_quantizer.scale` becomes `input_scale` — and
+        # saves them directly at the parent module level.
+        #
+        # This means we cannot simply rename keys like `weight_scale` back to
+        # `weight_quantizer.scale` when loading the state_dict.
+        # Otherwise, the `missing_keys` list would still expect keys such as
+        # `weight_scale`, `bias_scale`, etc.
+        #
+        # To fix this, we keep the expected state_dict keys (like `weight_scale`,
+        # `bias_scale`, etc.) unchanged, and during the conversion step, we explicitly
+        # assign their values into the corresponding quantizer attributes
+        # (`weight_quantizer.scale`, `input_quantizer.scale`, and so on).
+
+        # You can notice here that in target_patterns we use the same key as the source_patterns,
+        # this is because we just want to collect the tensors, and we will rename them later in the convert function.
+        # We cannot rename directly or else the missing_keys list will not be able to find the tensors.
+        converters = []
+        for key in CHECKPOINT_KEYS.keys():
+            converters.append(
+                WeightConverter(
+                    source_patterns=[key],
+                    target_patterns=key,
+                    operations=[QuarkDeserialize(self)],
+                )
+            )
+        return converters
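
The comment block in `get_weight_conversions` explains why the serialized keys keep their checkpoint names and are only redirected into the quantizer attributes during conversion. A rough sketch of that redirection idea, using an assumed mapping for illustration only (the actual CHECKPOINT_KEYS table lives in quantizer_quark.py and may differ):

# Sketch only: illustrates the serialized-key -> quantizer-attribute redirection
# described in the comments above. EXAMPLE_CHECKPOINT_KEYS is an assumption, not
# the table shipped in transformers.
EXAMPLE_CHECKPOINT_KEYS = {
    "weight_scale": "weight_quantizer.scale",
    "input_scale": "input_quantizer.scale",
    "bias_scale": "bias_quantizer.scale",
}

def redirect_key(checkpoint_key: str) -> str:
    # e.g. "model.layers.0.self_attn.q_proj.weight_scale" -> "...q_proj.weight_quantizer.scale"
    prefix, _, suffix = checkpoint_key.rpartition(".")
    target = EXAMPLE_CHECKPOINT_KEYS.get(suffix, suffix)
    return f"{prefix}.{target}" if prefix else target

print(redirect_key("model.layers.0.self_attn.q_proj.weight_scale"))
# model.layers.0.self_attn.q_proj.weight_quantizer.scale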
@@ -39,7 +39,6 @@ class SpQRHfQuantizer(HfQuantizer):
 
     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        self.quantization_config = quantization_config
 
     def validate_environment(self, *args, **kwargs):
         if not torch.cuda.is_available():
@@ -71,17 +70,15 @@ class SpQRHfQuantizer(HfQuantizer):
         self.modules_to_not_convert = self.get_modules_to_not_convert(
             model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
         )
-
         replace_with_spqr_linear(
             model,
             quantization_config=self.quantization_config,
             modules_to_not_convert=self.modules_to_not_convert,
         )
-        model.config.quantization_config = self.quantization_config
 
     @property
     def is_trainable(self):
         return False
 
-    def is_serializable(self, safe_serialization=None):
+    def is_serializable(self):
         return True
@@ -13,8 +13,6 @@
 # limitations under the License.
 import importlib
 import re
-import types
-from collections import defaultdict
 from typing import TYPE_CHECKING
 
 from packaging import version
@@ -37,17 +35,12 @@ if is_torch_available():
 
 if is_torch_available():
     import torch
-    import torch.nn as nn
 
 if is_torchao_available():
-    import torchao
-
-    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
+    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
         from torchao.prototype.safetensors.safetensors_support import (
             flatten_tensor_state_dict,
-            unflatten_tensor_state_dict,
         )
-        from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
 
 
 logger = logging.get_logger(__name__)
@@ -88,11 +81,6 @@ def _linear_extra_repr(self):
 
 
 if is_torchao_available():
-    SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
-        torchao.quantization.Float8WeightOnlyConfig,
-        torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
-    ]
-
     TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
 
 
@@ -101,9 +89,7 @@ class TorchAoHfQuantizer(HfQuantizer):
     Quantizer for torchao: https://github.com/pytorch/ao/
     """
 
-    requires_parameters_quantization = True
     requires_calibration = False
-    required_packages = ["torchao"]
 
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)
@@ -166,20 +152,16 @@ class TorchAoHfQuantizer(HfQuantizer):
             dtype = torch.float32
         return dtype
 
-    def get_state_dict_and_metadata(self, model, safe_serialization: bool | None = False):
+    def get_state_dict_and_metadata(self, model):
         """
-        If the model is safe serializable, we flatten the state dict of tensor subclasses so that it is compatible with
-        the safetensors format.
+        We flatten the state dict of tensor subclasses so that it is compatible with the safetensors format.
         """
-        if type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization:
-            if TORCHAO_VERSION >= version.parse("0.14.0"):
-                return flatten_tensor_state_dict(model.state_dict())
-            else:
-                raise RuntimeError(
-                    f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}"
-                )
+        if TORCHAO_VERSION >= version.parse("0.15.0"):
+            return flatten_tensor_state_dict(model.state_dict()), {}
         else:
-            return None, {}
+            raise RuntimeError(
+                f"In order to use safetensors with torchao, please use torchao version >= 0.15.0. Current version: {TORCHAO_VERSION}"
+            )
 
     def adjust_target_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
         from accelerate.utils import CustomDtype
@@ -237,9 +219,6 @@ class TorchAoHfQuantizer(HfQuantizer):
             ]
             return
 
-    def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
-        return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.full_ao_keys)]
-
     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         if self.pre_quantized:
             return False
@@ -249,8 +228,6 @@ class TorchAoHfQuantizer(HfQuantizer):
         # check if the param_name is not in self.modules_to_not_convert
         if any(key + "." in param_name or key == param_name for key in self.modules_to_not_convert):
             return False
-        elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys):
-            return True
 
         # we only quantize the weight of nn.Linear and nn.Embedding
         module, tensor_name = get_module_from_name(model, param_name)
@@ -276,148 +253,6 @@ class TorchAoHfQuantizer(HfQuantizer):
 
         return isinstance(module, tuple(_QUANTIZABLE)) and tensor_name == "weight"
 
-    def create_quantized_param(
-        self,
-        model: "PreTrainedModel",
-        param_value: "torch.Tensor",
-        param_name: str,
-        target_device: "torch.device",
-        **kwargs,
-    ):
-        """
-        Each nn.Linear layer that needs to be quantized is processed here.
-        First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module.
-        """
-        from torchao.quantization import quantize_
-
-        full_name = param_name
-        # Those are the pre quantized weights
-        if ":" in param_name:
-            param_name = param_name.rsplit(":", 1)[0]
-        module, tensor_name = get_module_from_name(model, param_name)
-
-        if self.pre_quantized:
-            # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was
-            # already done) - if it's unsafe-serialized (i.e. not safetensors), not need for anything either
-            is_unsafe_serialization = ":" not in full_name
-            if tensor_name == "bias" or is_unsafe_serialization:
-                module._parameters[tensor_name] = torch.nn.Parameter(
-                    param_value.to(target_device), requires_grad=param_value.requires_grad
-                )
-                return
-            # Sanity check for the new serialization format
-            elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata)):
-                raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
-
-            # Save the states for later quantization when they are all gathered
-            if not hasattr(self, "ao_params"):
-                self.ao_params = defaultdict(dict)
-            self.ao_params[param_name].update({full_name: param_value})
-
-            # We are ready for quantization in this case (we retrieved all the needed keys)
-            if len(self.ao_params[param_name]) == len(self.weight_ao_keys):
-                new_param = unflatten_tensor_state_dict(self.ao_params[param_name], self.metadata)[param_name]
-                # Set it
-                module._parameters[tensor_name] = torch.nn.Parameter(
-                    new_param.to(target_device), requires_grad=new_param.requires_grad
-                )
-
-                # Free memory
-                del self.ao_params[param_name]
-
-            # Add repr to the module
-            if isinstance(module, nn.Linear):
-                module.extra_repr = types.MethodType(_linear_extra_repr, module)
-        else:
-            module._parameters[tensor_name] = torch.nn.Parameter(
-                param_value, requires_grad=param_value.requires_grad
-            ).to(target_device)
-            # if we are quantizing tied parameters, to avoid tying the quantized weights
-            # the correct order to do it is
-            # 1. load the weight to model
-            # 2. run tie_weights to populate the weights
-            # 3. quantize
-            input_embed = model.get_input_embeddings()
-            if self.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
-                model.tie_weights()
-                setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)
-
-            # handle FqnToConfig, introduced in torchao 0.15.0+
-            if self.quantization_config._get_ao_version() >= version.Version("0.15.0"):
-                from torchao.quantization import FqnToConfig
-
-                config = self.quantization_config.get_apply_tensor_subclass()
-                if isinstance(config, FqnToConfig):
-                    module_fqn, top_level_param_name = param_name.rsplit(".", 1)
-                    c = None
-                    if param_name in config.fqn_to_config:
-                        assert not module_fqn.startswith("re:"), (
-                            "param fqn should not start with`re:`, which is used for specifying regex"
-                        )
-                        c = config.module_fqn_to_config[param_name]
-                    elif module_fqn in config.fqn_to_config:
-                        assert not module_fqn.startswith("re:"), (
-                            "module fqn should not start with`re:`, which is used for specifying regex"
-                        )
-                        c = config.module_fqn_to_config[module_fqn]
-                    # regex match module and param
-                    else:
-                        for maybe_module_fqn_pattern in config.fqn_to_config:
-                            # if key doesn't start with re, it is an exact fqn key, so we don't regex match
-                            if not maybe_module_fqn_pattern.startswith("re:"):
-                                continue
-                            # see if param matches first
-                            elif re.fullmatch(maybe_module_fqn_pattern[3:], param_name):
-                                c = config.module_fqn_to_config[maybe_module_fqn_pattern]
-                                break
-                            elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
-                                # we'll apply the config for first fully matched pattern
-                                c = config.module_fqn_to_config[maybe_module_fqn_pattern]
-                                break
-                        else:
-                            c = config.module_fqn_to_config.get("_default", None)
-
-                    if c is not None:
-                        if top_level_param_name == "weight":
-                            # we can apply the module config directly
-                            quantize_(module, c, (lambda x, fqn: True))
-                        else:
-                            # need to apply to custom param name
-                            custom_param_fqn_config = FqnToConfig({top_level_param_name: c})
-                            quantize_(module, custom_param_fqn_config, filter_fn=None)
-                        return
-
-            # handle ModuleFqnToConfig, introduced in torchao 0.12.0+
-            # TODO deprecate this when we deprecate ModuleFqnToConfig
-            elif self.quantization_config._get_ao_version() >= version.Version("0.12.0"):
-                from torchao.quantization import ModuleFqnToConfig
-
-                config = self.quantization_config.get_apply_tensor_subclass()
-                if isinstance(config, ModuleFqnToConfig):
-                    module_fqn, _ = param_name.rsplit(".", 1)
-                    c = None
-                    if module_fqn in config.module_fqn_to_config:
-                        assert not module_fqn.startswith("re:"), (
-                            "module fqn should not start with`re:`, which is used for specifying regex"
-                        )
-                        c = config.module_fqn_to_config[module_fqn]
-                    else:
-                        for maybe_module_fqn_pattern in config.module_fqn_to_config:
-                            if not maybe_module_fqn_pattern.startswith("re:"):
-                                continue
-                            elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
-                                # we'll apply the config for first fully matched pattern
-                                c = config.module_fqn_to_config[maybe_module_fqn_pattern]
-                                break
-                        else:
-                            c = config.module_fqn_to_config.get("_default", None)
-                    if c is not None:
-                        # filter_fn: not filtering out any modules
-                        quantize_(module, c, filter_fn=lambda x, fqn: True)
-                    return
-
-            quantize_(module, self.quantization_config.get_apply_tensor_subclass())
-
     def preprocess_model(self, model: "PreTrainedModel", config, dtype=None, checkpoint_files=None, **kwargs):
         """
         Setting model attributes and/or converting model before weights loading. At this point
@@ -450,30 +285,13 @@ class TorchAoHfQuantizer(HfQuantizer):
             return model
         return
 
-    def is_serializable(self, safe_serialization=None) -> bool:
-        if safe_serialization:
-            _is_torchao_serializable = type(
-                self.quantization_config.quant_type
-            ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0")
-            if not _is_torchao_serializable:
-                logger.warning(
-                    f"torchao quantized model only supports safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS}, \
-                    and torchao version >= 0.14.0, please set `safe_serialization` to False for \
-                    {type(self.quantization_config.quant_type)} and {TORCHAO_VERSION}."
-                )
-            return _is_torchao_serializable
-
-        _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
-            "0.25.0"
-        )
-        if not _is_torchao_serializable:
-            logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
-        if self.offload and self.quantization_config.modules_to_not_convert is None:
+    def is_serializable(self) -> bool:
+        _is_torchao_serializable = TORCHAO_VERSION >= version.parse("0.15.0")
+        if not TORCHAO_VERSION >= version.parse("0.15.0"):
             logger.warning(
-                "The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them."
-                "If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config."
+                "torchao quantized model only supports serialization for torchao version >= 0.15.0, please upgrade "
+                "your version to save the quantized model"
             )
-            return False
         return _is_torchao_serializable
 
     def get_accelerator_warm_up_factor(self):
@@ -548,15 +366,18 @@ class TorchAoHfQuantizer(HfQuantizer):
         if self.pre_quantized:
             return [
                 WeightConverter(
-                    source_patterns=["weight:qdata", "weight:scale", "weight:zero_point"],
-                    target_patterns="weight",
-                    operations=[TorchAoDeserialize(self)],
-                ),
-                WeightConverter(
-                    source_patterns=["weight:_data"],
+                    # TODO: incr flexibility by generalizing the source patterns to match the format of "_weight_"
+                    # note that the matching logic is greedy, so for ex, if _weight_scale is before _weight_scale_and_zero in this list, it will match _weight_scale always (this is incorrect)
+                    # thus, the order of source_patterns is intentional
+                    source_patterns=[
+                        "_weight_qdata",
+                        "_weight_scale_and_zero",
+                        "_weight_scale",
+                        "_weight_zero_point",
+                        "_weight_act_pre_scale",
+                    ],
                     target_patterns="weight",
                     operations=[TorchAoDeserialize(self)],
                 ),
-                # used for unsafe serialization
             ]
         return []
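
The TODO comment in the torchao `WeightConverter` stresses that source patterns are matched greedily, so `_weight_scale_and_zero` must come before `_weight_scale`. A small sketch (mine, with a simplified first-substring-match rule and a hypothetical flattened key) of why that ordering matters:

# Sketch only: a simplified first-match rule to show why pattern order is intentional.
PATTERNS = ["_weight_qdata", "_weight_scale_and_zero", "_weight_scale", "_weight_zero_point", "_weight_act_pre_scale"]

def first_match(key: str, patterns: list[str]) -> str | None:
    return next((p for p in patterns if p in key), None)

key = "model.layers.0.mlp.up_proj._weight_scale_and_zero"   # hypothetical flattened key
print(first_match(key, PATTERNS))        # _weight_scale_and_zero (correct with the order above)

bad_order = ["_weight_scale", "_weight_scale_and_zero"]
print(first_match(key, bad_order))       # _weight_scale (the shorter, wrong pattern wins)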
@@ -35,11 +35,9 @@ class VptqHfQuantizer(HfQuantizer):
     """
 
     requires_calibration = True
-    required_packages = ["vptq"]
 
     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        self.quantization_config = quantization_config
 
     def validate_environment(self, *args, **kwargs):
         if not is_accelerate_available():
@@ -48,21 +46,15 @@ class VptqHfQuantizer(HfQuantizer):
         if not is_vptq_available():
             raise ImportError("Using `vptq` quantization requires VPTQ>=0.0.4: `pip install -U vptq`")
 
+        if not torch.cuda.is_available():
+            raise RuntimeError("GPU is required to run VTPQ quantized model.")
+
     def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
         if dtype is None:
-            if torch.cuda.is_available():
-                dtype = torch.float16
-                logger.info(
-                    "CUDA available. Assuming VPTQ inference on GPU and loading the model in `torch.float16`. To overwrite it, set `dtype` manually."
-                )
-            else:
-                import vptq
-
-                device_availability = getattr(vptq, "device_availability", lambda device: False)
-                if device_availability("cpu") is True:
-                    raise RuntimeError("No GPU found. Please wait for the next release of VPTQ to use CPU inference")
-                dtype = torch.float32
-                logger.info("No GPU found. Assuming VPTQ inference on CPU and loading the model in `torch.float32`.")
+            dtype = torch.float16
+            logger.info(
+                "Assuming VPTQ inference on GPU and loading the model in `torch.float16`. To overwrite it, set `dtype` manually."
+            )
         return dtype
 
     def _process_model_before_weight_loading(
@@ -71,26 +63,20 @@ class VptqHfQuantizer(HfQuantizer):
         keep_in_fp32_modules: list[str] | None = None,
         **kwargs,
     ):
-        """
-        we don't have param like modules_to_not_convert to indicate which layers should not be quantized
-        because `quantization_config` include the layers that should be quantized
-        """
        from ..integrations import replace_with_vptq_linear
 
         self.modules_to_not_convert = self.get_modules_to_not_convert(
             model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
         )
-
         replace_with_vptq_linear(
             model,
             quantization_config=self.quantization_config,
             modules_to_not_convert=self.modules_to_not_convert,
         )
-        model.config.quantization_config = self.quantization_config
 
     @property
     def is_trainable(self) -> bool:
         return False
 
-    def is_serializable(self, safe_serialization=None):
+    def is_serializable(self):
         return True
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 from typing import Any
 
 
@@ -19,3 +20,22 @@ def get_module_from_name(module, tensor_name: str) -> tuple[Any, str]:
     module_name, tensor_name = tensor_name.rsplit(".", 1)
     module = module.get_submodule(module_name)
     return module, tensor_name
+
+
+def should_convert_module(full_name, patterns: list[str] | None = None):
+    if patterns is None:
+        return True
+
+    # We should avoid converting in the following situations:
+    # 1. The pattern appears as a prefix followed by a dot in `full_name`
+    #    (e.g., "model.decoder.layer.11." matches "model.decoder.layer.11.attn.weight").
+    # 2. The pattern matches `full_name` exactly or via regex
+    #    (e.g., "lm_head" matches "lm_head"; "model.decoder.layer.*" matches "model.decoder.layer.11.attn.weight").
+    # 3. `full_name` ends with the pattern
+    #    (e.g., "fc1" matches "model.decoder.layers.23.fc1").
+
+    should_not_convert = any(
+        re.match(f"{key}\\.", full_name) or re.match(f"{key}", full_name) or full_name.endswith(key)
+        for key in patterns
+    )
+    return not should_not_convert
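
The new `should_convert_module` helper encodes the three skip rules listed in its comments. A short usage sketch, assuming the import path matches the file listed above (transformers/quantizers/quantizers_utils.py) in 5.0.0rc1:

# Sketch only: exercising the three skip rules described in the comments above.
from transformers.quantizers.quantizers_utils import should_convert_module

patterns = ["lm_head", "model.decoder.layer.11.", "fc1"]

print(should_convert_module("lm_head", patterns))                              # False: exact match
print(should_convert_module("model.decoder.layer.11.attn.weight", patterns))   # False: pattern is a dotted prefix
print(should_convert_module("model.decoder.layers.23.fc1", patterns))          # False: name ends with the pattern
print(should_convert_module("model.encoder.layer.0.attn.weight", patterns))    # True: no pattern applies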