transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0

transformers/integrations/bitnet.py

@@ -1,3 +1,4 @@
+ from ..quantizers.quantizers_utils import should_convert_module
  from ..utils import is_accelerate_available, is_torch_available, logging


@@ -314,113 +315,57 @@ class AutoBitLinear(nn.Linear):
          return output


- def _replace_with_bitnet_linear(
-     model,
-     modules_to_not_convert=None,
-     current_key_name=None,
-     quantization_config=None,
-     has_been_replaced=False,
-     pre_quantized=False,
- ):
+ def replace_with_bitnet_linear(model, modules_to_not_convert: list[str] | None = None, quantization_config=None):
      """
-     Private method that wraps the recursion for module replacement.
+     Public method that replaces the linear layers of the given model with bitnet quantized layers.

-     Returns the converted model and a boolean that indicates if the conversion has been successful or not.
-     """
-
-     if current_key_name is None:
-         current_key_name = []
-
-     for name, module in model.named_children():
-         if current_key_name is None:
-             current_key_name = []
-         current_key_name.append(name)
-
-         # Check if the current key is not in the `modules_to_not_convert`
-         if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
-             with init_empty_weights():
-                 if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
-                     in_features = module.in_features
-                     out_features = module.out_features
-                     if quantization_config and quantization_config.linear_class == "autobitlinear":
-                         model._modules[name] = AutoBitLinear(
-                             in_features=in_features,
-                             out_features=out_features,
-                             bias=module.bias is not None,
-                             device=module.weight.device,
-                             dtype=module.weight.dtype,
-                             online_quant=(quantization_config.quantization_mode == "online"),
-                             use_rms_norm=quantization_config.use_rms_norm,
-                             rms_norm_eps=quantization_config.rms_norm_eps,
-                         )
-                         if quantization_config.quantization_mode == "offline":
-                             model._modules[name].requires_grad_(False)
-                     else:
-                         model._modules[name] = BitLinear(
-                             in_features=in_features,
-                             out_features=out_features,
-                             bias=module.bias is not None,
-                             device=module.weight.device,
-                             dtype=module.weight.dtype,
-                             use_rms_norm=quantization_config.use_rms_norm if quantization_config else False,
-                             rms_norm_eps=quantization_config.rms_norm_eps if quantization_config else 1e-6,
-                         )
-                         model._modules[name].requires_grad_(False)
-                     has_been_replaced = True
-
-         if len(list(module.children())) > 0:
-             _, has_been_replaced = _replace_with_bitnet_linear(
-                 module,
-                 modules_to_not_convert=modules_to_not_convert,
-                 current_key_name=current_key_name,
-                 quantization_config=quantization_config,
-                 has_been_replaced=has_been_replaced,
-             )
-         # Remove the last key for recursion
-         current_key_name.pop(-1)
-     return model, has_been_replaced
-
-
- def replace_with_bitnet_linear(
-     model,
-     modules_to_not_convert=None,
-     current_key_name=None,
-     quantization_config=None,
-     pre_quantized=False,
- ):
-     """
-     A helper function to replace all `torch.nn.Linear` modules by `BitLinear158` modules`.
-
-     The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
-     be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
-     CPU/GPU memory is required to run this function. Each weight will be quantized along the channel.
-
-     Parameters:
+     Args:
          model (`torch.nn.Module`):
-             Input model or `torch.nn.Module` as the function is run recursively.
-         modules_to_not_convert (`list[`str`]`, *optional*, defaults to `["lm_head"]`):
-             Names of the modules to not convert in `BitLinear`. In practice we keep the `lm_head` in full precision
-             for numerical stability reasons.
-         current_key_name (`list[`str`]`, *optional*):
-             An array to track the current key of the recursion. This is used to check whether the current key (part of
-             it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
-             `disk`).
+             The model to convert, can be any `torch.nn.Module` instance.
+         modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
+             A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
+             converted.
+         quantization_config (`BitNetConfig`):
+             The quantization config object that contains the quantization parameters.
      """
-     modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
-     if quantization_config and quantization_config.modules_to_not_convert is not None:
-         modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
-     modules_to_not_convert = list(set(modules_to_not_convert))
-     model, has_been_replaced = _replace_with_bitnet_linear(
-         model,
-         modules_to_not_convert,
-         current_key_name,
-         quantization_config,
-         pre_quantized=pre_quantized,
-     )
+
+     has_been_replaced = False
+     # we need this to correctly materialize the weights during quantization
+     for module_name, module in model.named_modules():
+         if not should_convert_module(module_name, modules_to_not_convert):
+             continue
+         with init_empty_weights():
+             if isinstance(module, nn.Linear):
+                 if quantization_config and quantization_config.linear_class == "autobitlinear":
+                     new_module = AutoBitLinear(
+                         in_features=module.in_features,
+                         out_features=module.out_features,
+                         bias=module.bias is not None,
+                         device=module.weight.device,
+                         dtype=module.weight.dtype,
+                         online_quant=(quantization_config.quantization_mode == "online"),
+                         use_rms_norm=quantization_config.use_rms_norm,
+                         rms_norm_eps=quantization_config.rms_norm_eps,
+                     )
+                     if quantization_config.quantization_mode == "offline":
+                         new_module.requires_grad_(False)
+                 else:
+                     new_module = BitLinear(
+                         in_features=module.in_features,
+                         out_features=module.out_features,
+                         bias=module.bias is not None,
+                         device=module.weight.device,
+                         dtype=module.weight.dtype,
+                         use_rms_norm=quantization_config.use_rms_norm if quantization_config else False,
+                         rms_norm_eps=quantization_config.rms_norm_eps if quantization_config else 1e-6,
+                     )
+                     new_module.requires_grad_(False)
+                 model.set_submodule(module_name, new_module)
+                 has_been_replaced = True

      if not has_been_replaced:
          logger.warning(
-             "You are loading your model using bitnet but no linear modules were found in your model."
+             "You are loading your model using eetq but no linear modules were found in your model."
              " Please double check your model architecture, or submit an issue on github if you think this is"
              " a bug."
          )
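
Editor's note: the bitnet hunk above (and the bitsandbytes hunks below) drop the old recursive `named_children()` traversal in favor of a single flat pass over `model.named_modules()`, a `should_convert_module` check, and `model.set_submodule()` to swap each leaf in place. The following sketch is illustrative only: `ToyQuantLinear`, `replace_linears`, and the simplified `should_convert_module` are stand-ins for the library's classes and helper, and it assumes a PyTorch version that provides `nn.Module.set_submodule`.

from torch import nn


def should_convert_module(module_name: str, modules_to_not_convert: list[str] | None) -> bool:
    # Simplified assumption: a module is skipped when any excluded key appears in its dotted path.
    if not modules_to_not_convert:
        return True
    return not any(key in module_name for key in modules_to_not_convert)


class ToyQuantLinear(nn.Linear):
    """Stand-in for a quantized linear layer such as BitLinear or bnb.nn.Linear8bitLt."""


def replace_linears(model: nn.Module, modules_to_not_convert: list[str] | None = None) -> bool:
    has_been_replaced = False
    for module_name, module in model.named_modules():
        if not should_convert_module(module_name, modules_to_not_convert):
            continue
        if type(module) is nn.Linear:
            new_module = ToyQuantLinear(module.in_features, module.out_features, bias=module.bias is not None)
            new_module.requires_grad_(False)
            model.set_submodule(module_name, new_module)  # swap the leaf by its dotted path
            has_been_replaced = True
    return has_been_replaced


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
print(replace_linears(model, modules_to_not_convert=["2"]))  # True; the last projection stays unconverted
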
transformers/integrations/bitsandbytes.py

@@ -1,8 +1,7 @@
  import inspect
- from inspect import signature

  from ..core_model_loading import ConversionOps
- from ..quantizers.quantizers_utils import get_module_from_name
+ from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
  from ..utils import (
      get_available_devices,
      is_accelerate_available,
@@ -44,7 +43,7 @@ class Bnb4bitQuantize(ConversionOps):
          we need to store some parameters to create the quantized weight. For example, bnb requires 6 values that are stored in the checkpoint to recover the quantized weight. So we store them in a dict that it stored in hf_quantizer for now as we can't save it in the op since we create an op per tensor.
          """
          value = list(input_dict.values())[0]
-         value = value[0] if isinstance(value, list) else value
+         value = value[0]

          # update param name to get the weights instead of the quantized stats
          module, _ = get_module_from_name(model, full_layer_name)
@@ -156,134 +155,77 @@ class Bnb8bitDeserialize(ConversionOps):
          return {key_weight: new_value}


- def _replace_with_bnb_linear(
-     model,
-     modules_to_not_convert=None,
-     current_key_name=None,
+ def replace_with_bnb_linear(
+     model: torch.nn.Module,
+     modules_to_not_convert: list[str] | None = None,
      quantization_config=None,
-     has_been_replaced=False,
      pre_quantized=False,
  ):
      """
-     Private method that wraps the recursion for module replacement.
+     A helper function to replace all `torch.nn.Linear` modules by bnb modules from the `bitsandbytes` library.

-     Returns the converted model and a boolean that indicates if the conversion has been successful or not.
+     Args:
+         model (`torch.nn.Module`):
+             The model to convert, can be any `torch.nn.Module` instance.
+         modules_to_not_convert (`list[str]`, defaults to `None`):
+             A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
+             converted.
+         quantization_config (`BitsAndBytesConfig`):
+             The quantization config object that contains the quantization parameters.
+         pre_quantized (`book`, defaults to `False`):
+             Whether the model is pre-quantized or not
      """
-     for name, module in model.named_children():
-         if current_key_name is None:
-             current_key_name = []
-         current_key_name.append(name)
-
-         if (isinstance(module, (nn.Linear, Conv1D))) and name not in modules_to_not_convert:
-             # Check if the current key is not in the `modules_to_not_convert`
-             current_key_name_str = ".".join(current_key_name)
-             if not any(
-                 (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
-             ):
-                 with init_empty_weights():
-                     if isinstance(module, Conv1D):
-                         in_features, out_features = module.weight.shape
-                     else:
-                         in_features = module.in_features
-                         out_features = module.out_features
-
-                     if quantization_config.quantization_method() == "llm_int8":
-                         new_module = bnb.nn.Linear8bitLt(
-                             in_features,
-                             out_features,
-                             module.bias is not None,
-                             has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
-                             threshold=quantization_config.llm_int8_threshold,
+     has_been_replaced = False
+     # we need this to correctly materialize the weights during quantization
+     for module_name, module in model.named_modules():
+         if not should_convert_module(module_name, modules_to_not_convert):
+             continue
+         new_module = None
+         with init_empty_weights():
+             if isinstance(module, (nn.Linear, Conv1D)):
+                 if isinstance(module, Conv1D):
+                     in_features, out_features = module.weight.shape
+                 else:
+                     in_features = module.in_features
+                     out_features = module.out_features
+                 if quantization_config.quantization_method() == "llm_int8":
+                     new_module = bnb.nn.Linear8bitLt(
+                         in_features,
+                         out_features,
+                         module.bias is not None,
+                         has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
+                         threshold=quantization_config.llm_int8_threshold,
+                     )
+                     if pre_quantized:
+                         # this is kind of an edge case when supporting both loading and quantization ...
+                         # we need to set the right dtype as we cast the checkpoint with the dtype of the meta model
+                         new_module.weight.data = new_module.weight.data.to(dtype=torch.int8)
+                 else:
+                     new_module = bnb.nn.Linear4bit(
+                         in_features,
+                         out_features,
+                         module.bias is not None,
+                         quantization_config.bnb_4bit_compute_dtype,
+                         compress_statistics=quantization_config.bnb_4bit_use_double_quant,
+                         quant_type=quantization_config.bnb_4bit_quant_type,
+                         quant_storage=quantization_config.bnb_4bit_quant_storage,
+                     )
+                     if pre_quantized:
+                         # same here
+                         new_module.weight.data = new_module.weight.data.to(
+                             dtype=quantization_config.bnb_4bit_quant_storage
                          )
-                         if pre_quantized:
-                             new_module.weight.data = new_module.weight.data.to(dtype=torch.int8)
-                         model._modules[name] = new_module
-                         has_been_replaced = True
-                     else:
-                         if (
-                             quantization_config.llm_int8_skip_modules is not None
-                             and name in quantization_config.llm_int8_skip_modules
-                         ):
-                             pass
-                         else:
-                             extra_kwargs = (
-                                 {"quant_storage": quantization_config.bnb_4bit_quant_storage}
-                                 if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters)
-                                 else {}
-                             )
-                             new_module = bnb.nn.Linear4bit(
-                                 in_features,
-                                 out_features,
-                                 module.bias is not None,
-                                 quantization_config.bnb_4bit_compute_dtype,
-                                 compress_statistics=quantization_config.bnb_4bit_use_double_quant,
-                                 quant_type=quantization_config.bnb_4bit_quant_type,
-                                 **extra_kwargs,
-                             )
-                             if pre_quantized:
-                                 # this is kind of an edge case when supporting both loading and quantization ...
-                                 # we need to set the right dtype as we cast the checkpoint with the dtype of the meta model
-                                 new_module.weight.data = new_module.weight.data.to(dtype=torch.uint8)
-                             model._modules[name] = new_module
-                             has_been_replaced = True
+         if new_module is not None:
              # Store the module class in case we need to transpose the weight later
-                     model._modules[name].source_cls = type(module)
+             new_module.source_cls = type(module)
              # Force requires grad to False to avoid unexpected errors
-                     model._modules[name].requires_grad_(False)
-         if len(list(module.children())) > 0:
-             _, has_been_replaced = _replace_with_bnb_linear(
-                 module,
-                 modules_to_not_convert,
-                 current_key_name,
-                 quantization_config,
-                 has_been_replaced=has_been_replaced,
-                 pre_quantized=pre_quantized,
-             )
-         # Remove the last key for recursion
-         current_key_name.pop(-1)
-     return model, has_been_replaced
-
-
- def replace_with_bnb_linear(
-     model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
- ):
-     """
-     A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
-     library. This will enable running your models using mixed int8 precision as described by the paper `LLM.int8():
-     8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
-     version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
-     bitsandbytes`
-
-     The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
-     be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
-     CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating a
-     matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in fp16
-     (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no
-     predictive degradation is possible for very large models (>=176B parameters).
-
-     Parameters:
-         model (`torch.nn.Module`):
-             Input model or `torch.nn.Module` as the function is run recursively.
-         modules_to_not_convert (`list[`str`]`, *optional*, defaults to `["lm_head"]`):
-             Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
-             for numerical stability reasons.
-         current_key_name (`list[`str`]`, *optional*):
-             An array to track the current key of the recursion. This is used to check whether the current key (part of
-             it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
-             `disk`).
-         quantization_config ('transformers.utils.quantization_config.BitsAndBytesConfig'):
-             To configure and manage settings related to quantization, a technique used to compress neural network models
-             by reducing the precision of the weights and activations, thus making models more efficient in terms of both
-             storage and computation.
-     """
-     modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
-     model, has_been_replaced = _replace_with_bnb_linear(
-         model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
-     )
+             new_module.requires_grad_(False)
+             model.set_submodule(module_name, new_module)
+             has_been_replaced = True

      if not has_been_replaced:
          logger.warning(
-             "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+             "You are loading your model using eetq but no linear modules were found in your model."
              " Please double check your model architecture, or submit an issue on github if you think this is"
              " a bug."
          )
@@ -341,95 +283,43 @@ def _create_accelerate_new_hook(old_hook):
      return new_hook


- def _dequantize_and_replace(
+ def dequantize_and_replace(
      model,
-     dtype,
-     modules_to_not_convert=None,
-     current_key_name=None,
      quantization_config=None,
-     has_been_replaced=False,
  ):
      """
      Converts a quantized model into its dequantized original version. The newly converted model will have
      some performance drop compared to the original model before quantization - use it only for specific usecases
      such as QLoRA adapters merging.

-     Returns the converted model and a boolean that indicates if the conversion has been successful or not.
+     Returns the converted model.
      """
      quant_method = quantization_config.quantization_method()

      target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit

-     for name, module in model.named_children():
-         if current_key_name is None:
-             current_key_name = []
-         current_key_name.append(name)
-
-         if isinstance(module, target_cls) and name not in modules_to_not_convert:
-             # Check if the current key is not in the `modules_to_not_convert`
-             current_key_name_str = ".".join(current_key_name)
-
-             if not any(
-                 (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
-             ):
+     for module_name, module in model.named_modules():
+         if isinstance(module, target_cls):
+             with init_empty_weights():
                  bias = getattr(module, "bias", None)
-
-                 device = module.weight.device
-                 with init_empty_weights():
-                     new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)
-
-                 if quant_method == "llm_int8":
-                     state = module.state
-                 else:
-                     state = None
-
-                 new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, dtype, state))
-
-                 if bias is not None:
-                     new_module.bias = bias
-
-                 # Create a new hook and attach it in case we use accelerate
-                 if hasattr(module, "_hf_hook"):
-                     old_hook = module._hf_hook
-                     new_hook = _create_accelerate_new_hook(old_hook)
-
-                     remove_hook_from_module(module)
-                     add_hook_to_module(new_module, new_hook)
-
-                 new_module.to(device)
-                 model._modules[name] = new_module
-                 has_been_replaced = True
-         if len(list(module.children())) > 0:
-             _, has_been_replaced = _dequantize_and_replace(
-                 module,
-                 dtype,
-                 modules_to_not_convert,
-                 current_key_name,
-                 quantization_config,
-                 has_been_replaced=has_been_replaced,
-             )
-         # Remove the last key for recursion
-         current_key_name.pop(-1)
-     return model, has_been_replaced
-
-
- def dequantize_and_replace(
-     model,
-     modules_to_not_convert=None,
-     quantization_config=None,
- ):
-     model, has_been_replaced = _dequantize_and_replace(
-         model,
-         model.dtype,
-         modules_to_not_convert=modules_to_not_convert,
-         quantization_config=quantization_config,
-     )
+                 new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)
+             state = module.state if quant_method == "llm_int8" else None
+             new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, model.dtype, state))
+             if bias is not None:
+                 new_module.bias = bias
+             if hasattr(module, "_hf_hook"):
+                 old_hook = module._hf_hook
+                 new_hook = _create_accelerate_new_hook(old_hook)
+                 remove_hook_from_module(module)
+                 add_hook_to_module(new_module, new_hook)
+             new_module.to(module.weight.device)
+             model.set_submodule(module_name, new_module)
+             has_been_replaced = True

      if not has_been_replaced:
          logger.warning(
              "For some reason the model has not been properly dequantized. You might see unexpected behavior."
          )
-
      return model


@@ -437,8 +327,6 @@ def validate_bnb_backend_availability(raise_exception=False):
      """
      Validates if the available devices are supported by bitsandbytes, optionally raising an exception if not.
      """
-     import bitsandbytes as bnb
-
      bnb_supported_devices = getattr(bnb, "supported_torch_devices", set())
      available_devices = set(get_available_devices())

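
Editor's note: the `dequantize_and_replace` hunk above applies the same flat traversal in reverse, rebuilding a plain `nn.Linear` from the dequantized weight and swapping it in with `set_submodule`; the real function also moves any accelerate hook onto the new module. The sketch below is a dependency-free illustration: `FakeQuantLinear` and `fake_dequantize` are stand-ins for `bnb.nn.Linear8bitLt`/`Linear4bit` and `dequantize_bnb_weight`, and hook handling is omitted.

import torch
from torch import nn


class FakeQuantLinear(nn.Linear):
    """Stand-in for a quantized linear layer whose weight lives in a packed format."""


def fake_dequantize(weight: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # The real helper reconstructs the floating-point weight from quantized storage;
    # here we simply detach and cast.
    return weight.detach().to(dtype)


def dequantize_and_replace(model: nn.Module, dtype: torch.dtype = torch.float32) -> nn.Module:
    for module_name, module in model.named_modules():
        if isinstance(module, FakeQuantLinear):
            bias = getattr(module, "bias", None)
            new_module = nn.Linear(module.in_features, module.out_features, bias=bias is not None)
            new_module.weight = nn.Parameter(fake_dequantize(module.weight, dtype))
            if bias is not None:
                new_module.bias = bias
            new_module.to(module.weight.device)
            model.set_submodule(module_name, new_module)  # same flat swap as the quantize path
    return model


model = nn.Sequential(FakeQuantLinear(8, 8), nn.ReLU(), FakeQuantLinear(8, 2))
print(dequantize_and_replace(model))
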
transformers/integrations/deepspeed.py

@@ -333,6 +333,7 @@ def _load_state_dict_into_zero3_model(model_to_load, state_dict):
          for name, child in module._modules.items():
              if child is not None:
                  load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
+                 child._is_hf_initialized = True

      load(model_to_load, state_dict, assign_to_params_buffers=False)
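
Editor's note: the deepspeed hunk above marks every child that was just loaded with `_is_hf_initialized = True`. That flag is how transformers records that a module's weights already came from the checkpoint, so a later weight-initialization pass can skip them. The sketch below illustrates the convention with a stand-in `initialize_missing_weights` function; it is not the library's actual post-load code.

from torch import nn


def initialize_missing_weights(model: nn.Module) -> None:
    # Stand-in for the post-load initialization pass: anything flagged as already
    # initialized (here, because it was filled from the checkpoint) is left untouched.
    for module in model.modules():
        if getattr(module, "_is_hf_initialized", False):
            continue
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)


model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
model[0]._is_hf_initialized = True   # pretend this child was loaded by the ZeRO-3 load() above
initialize_missing_weights(model)    # only model[1] gets (re-)initialized
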