transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -105,8 +105,9 @@ class GenerationConfig(PushToHubMixin):
     > Parameters that control the length of the output

     max_length (`int`, *optional*, defaults to 20):
-        The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
-        `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set.
+        `max_new_tokens` is recommended for controlling how many tokens the model generates.
+        `max_length` remains for backward compatibility.
+
     max_new_tokens (`int`, *optional*):
         The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
     min_length (`int`, *optional*, defaults to 0):
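
The reworded `max_length` entry above points users to `max_new_tokens`. A minimal, hedged sketch of the recommended call follows; the `gpt2` checkpoint is only an illustration, any causal LM works the same way:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; the pattern is identical for any causal LM
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# max_new_tokens bounds only the freshly generated tokens, independent of the prompt length,
# which is why the docstring now recommends it over the legacy max_length
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```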
@@ -15,12 +15,16 @@
 from .cache import PagedAttentionCache
 from .continuous_api import ContinuousBatchingManager, ContinuousMixin
 from .requests import RequestState, RequestStatus
+from .scheduler import FIFOScheduler, PrefillFirstScheduler, Scheduler


 __all__ = [
     "ContinuousBatchingManager",
     "ContinuousMixin",
+    "FIFOScheduler",
     "PagedAttentionCache",
+    "PrefillFirstScheduler",
     "RequestState",
     "RequestStatus",
+    "Scheduler",
 ]
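
With the re-exports above, the scheduler classes should be importable directly from the continuous batching subpackage. A small sketch (their constructor arguments are not shown in this diff, so only the imports are illustrated):

```python
# Base class plus the two concrete scheduling policies now exposed in __all__
from transformers.generation.continuous_batching import (
    FIFOScheduler,
    PrefillFirstScheduler,
    Scheduler,
)

print(Scheduler, FIFOScheduler, PrefillFirstScheduler)
```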
@@ -29,7 +29,7 @@ from tqdm import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm

 from ...configuration_utils import PretrainedConfig
-from ...generation.configuration_utils import GenerationConfig
+from ...generation.configuration_utils import CompileConfig, GenerationConfig
 from ...generation.logits_process import LogitsProcessor
 from ...utils.logging import logging
 from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
@@ -45,17 +45,20 @@ generation goes on, there are two dimensions that change:
 - the number of keys/values tokens (KV), which grows as the cache does

 To solve this, we slice along those dimensions to fixed lengths. The size of the slices is controlled by the variables
-below: NUM_X_CUDA_GRAPHS means that we create at most NUM_X_CUDA_GRAPHS graphs for the X dimension. So if the maximum
-number of queries tokens is 1000, and NUM_Q_CUDA_GRAPHS is 4, we will slice the number of queries token by intervals of
-1000 / 4 = 250 tokens, ie. to 250, 500, 750 or 1000 queries tokens.
+num_x_padding_intervals: NUM_X_PADDING_INTERVALS means that we create at most NUM_X_PADDING_INTERVALS graphs for the X
+dimension. So if the maximum number of queries tokens is 1000, and NUM_Q_PADDING_INTERVALS is 4, we will slice the
+number of queries token by intervals of 1000 / 4 = 250 tokens, ie. to 250, 500, 750 or 1000 queries tokens.

 Smaller slices means more granularity and thus less padding. But since each graph takes up space on the GPU and time to
 create, we don't want to many graphs. And since the size of the KV dimension is the number of queries tokens plus the
 number of tokens cached, dimension of KV is usually much larger than the dimension of Q. So we have more granularity
 for the KV dimension than the query dimension.
+
+This variable used to be called NUM_X_CUDA_GRAPHS, but we renamed it to NUM_X_PADDING_INTERVALS because it is used for
+padding in the case of cuda graphs AND torch.compile.
 """
-NUM_Q_CUDA_GRAPHS = 4
-NUM_KV_CUDA_GRAPHS = 8
+NUM_Q_PADDING_INTERVALS = 4
+NUM_KV_PADDING_INTERVALS = 8


 def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
@@ -63,7 +66,7 @@ def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
     interval_size = max_value // nb_intervals
     if interval_size == 0:
         return max_value
-    padded = ceil(size / interval_size) * interval_size
+    padded = ceil(size / interval_size) * interval_size if size > 0 else interval_size
    return min(padded, max_value)


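To make the interval-padding scheme concrete, here is a standalone copy of `pad_by_intervals` as it reads after this change (including the new `size > 0` guard), run on the 1000-token / 4-interval example from the docstring above:

```python
from math import ceil

def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
    # Round `size` up to the next multiple of (max_value // nb_intervals), capped at max_value
    interval_size = max_value // nb_intervals
    if interval_size == 0:
        return max_value
    padded = ceil(size / interval_size) * interval_size if size > 0 else interval_size
    return min(padded, max_value)

# max_value=1000 and nb_intervals=4 pad every size to 250, 500, 750 or 1000
print([pad_by_intervals(s, 1000, 4) for s in (1, 250, 251, 999)])  # [250, 250, 500, 1000]
# The new guard: a size of 0 now pads to one full interval (250) instead of staying 0
print(pad_by_intervals(0, 1000, 4))
```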
@@ -188,6 +191,8 @@ class ContinuousBatchProcessor:
         scheduler: Scheduler,
         manual_eviction: bool,
         use_cuda_graph: bool,
+        q_padding_intervals: int,
+        kv_padding_intervals: int,
     ) -> None:
         """Initialize the continuous batch processor.

@@ -221,7 +226,14 @@ class ContinuousBatchProcessor:
         # Accumulator for batch scheduling
         self.requests_in_batch: list[RequestState] = []
         # Cuda graphs for the generation step
+        self.q_padding_intervals = q_padding_intervals
+        self.kv_padding_intervals = kv_padding_intervals
         self._graphs: dict[tuple[int, int], torch.cuda.CUDAGraph] | None = {} if use_cuda_graph else None
+        # Compile-related arguments
+        self.compile_config: CompileConfig | None = getattr(generation_config, "compile_config", None)
+        self._forward_process_and_sample_is_compiled = False
+
+        self._pad_inputs = use_cuda_graph or (self.compile_config is not None and not self.compile_config.dynamic)

         # Set up metrics collector
         self.max_batch_tokens = cache.max_batch_tokens
@@ -627,28 +639,39 @@ class ContinuousBatchProcessor:
     def _generation_step(self, model: nn.Module, logit_processor: LogitsProcessor, do_sample: bool) -> None:
         """Perform a single generation step."""

-        # If cuda graphs are disabled, we just use the actual size
+        # If a compile config is specified, we compile the forward pass once in a wrapper
+        if self.compile_config is not None and not self._forward_process_and_sample_is_compiled:
+            self._forward_process_and_sample = torch.compile(
+                self._forward_process_and_sample,
+                fullgraph=self.compile_config.fullgraph,
+                mode=self.compile_config.mode,
+                dynamic=self.compile_config.dynamic,
+                backend=self.compile_config.backend,
+                options=self.compile_config.options,
+            )
+            self._forward_process_and_sample_is_compiled = True
+
+        # If inputs are static sized, we find the padded sizes of the queries and keys/values
+        if self._pad_inputs:
+            padded_q = pad_by_intervals(self.actual_query_length, self.max_batch_tokens, self.q_padding_intervals)
+            max_read_index_size = max(self.actual_index_sizes[i][0] for i in range(self.cache.num_groups))
+            padded_read_index_size = pad_by_intervals(
+                max_read_index_size - self.max_batch_tokens,
+                self.cache.num_blocks * self.cache.block_size,
+                self.kv_padding_intervals,
+            )
+        else:
+            padded_q, padded_read_index_size = 0, 0
+        # Retrieve the model kwargs with or without padding
+        batch_data = self.get_model_kwargs(padded_q, padded_read_index_size)
+
+        # If we are not using cuda graphs, we perform the generation step and return
         if self._graphs is None:
-            batch_data = self.get_model_kwargs()
             self._forward_process_and_sample(model, batch_data, logit_processor, do_sample)
             return None

-        # Determine the padded size of the queries and keys/values
-        padded_q = pad_by_intervals(self.actual_query_length, self.max_batch_tokens, NUM_Q_CUDA_GRAPHS)
-
-        max_read_index_size = max(self.actual_index_sizes[i][0] for i in range(self.cache.num_groups))
-        padded_read_index_size = pad_by_intervals(
-            max_read_index_size - self.max_batch_tokens,
-            self.cache.num_blocks * self.cache.block_size,
-            NUM_KV_CUDA_GRAPHS,
-        )
-
-        # Get the batch data and the associated graph
-        batch_data = self.get_model_kwargs(padded_q, padded_read_index_size)
-
-        graph = self._graphs.get((padded_q, padded_read_index_size))
-
         # If we have a graph that fits, we replay it
+        graph = self._graphs.get((padded_q, padded_read_index_size))
         if graph is not None:
             graph.replay()
             return None
@@ -673,7 +696,6 @@ class ContinuousBatchProcessor:
     ) -> None:
         """This function performs the forward pass, logits processing, and sampling; which are broken down into smaller
         function to be easier to trace with OpenTelemetry."""
-        # with torch.no_grad():
         logits = self._model_forward(model, batch_data)
         # if self.log_prob_generation: batch_processor.output_probs.copy_(logits) # TODO
         probs = self._process_logit(batch_data, logits, logit_processor)
@@ -691,6 +713,7 @@ class ContinuousBatchProcessor:
         # Handle shape compatibility: logit processors expect 2D tensors [batch_size, vocab_size]
         # but continuous batching always produces 3D tensors [batch_size, seq_len, vocab_size]
         batch_size, seq_len, vocab_size = logits.shape
+        # NOTE: to be an exact match with generate, we should also convert logits2d to float32 here, but it's not needed in practice
         logits_2d = logits.view(batch_size * seq_len, vocab_size)
         input_ids_2d = batch_data["input_ids"].view(batch_size * seq_len)
         # Process with 2D tensors
@@ -727,8 +750,8 @@ class ContinuousBatchingManager:
         generation_config: GenerationConfig,
         manual_eviction: bool = False,
         max_queue_size: int = 0,
-        num_q_cuda_graphs: int = 0,
-        num_kv_cuda_graphs: int = 0,
+        num_q_padding_intervals: int = 0,
+        num_kv_padding_intervals: int = 0,
         allow_prefix_sharing: bool = True,
     ) -> None:
         """Initialize the continuous batching manager.
@@ -737,19 +760,13 @@ class ContinuousBatchingManager:
             model: The language model for generation
             generation_config: Configuration for generation parameters
             max_queue_size: Maximum size of the request queue (0 = unlimited)
-            num_q_cuda_graphs: (optional) Number of CUDA graphs to use for the query dimension
-            num_kv_cuda_graphs: (optional) Number of CUDA graphs to use for the keys/values dimension
+            num_q_padding_intervals: (optional) Number of intervals used to pad the query dimension
+            num_kv_padding_intervals: (optional) Number of intervals used to pad the keys/values dimension
             allow_prefix_sharing: (optional) Whether to allow prefix sharing if the model has only full attention layers
         """
+        # Reloade paged version if necessary
         if "paged|" not in model.config._attn_implementation:
-            attn_implementation = f"paged|{model.config._attn_implementation}"
-            model.config._attn_implementation = attn_implementation
-
-            # lazy loading flash attention including kernel variations
-            if "flash" in attn_implementation:
-                from ...modeling_flash_attention_utils import lazy_import_paged_flash_attention
-
-                lazy_import_paged_flash_attention(attn_implementation)
+            model.set_attn_implementation(f"paged|{model.config._attn_implementation}")

         self.model = model.eval()
         generation_config = model.generation_config if generation_config is None else generation_config

@@ -764,38 +781,69 @@ class ContinuousBatchingManager:
         self.model.generation_config.top_p = None
         self.do_sample = getattr(generation_config, "do_sample", True)
         self.logit_processor = self.model._get_logits_processor(generation_config)
-        use_cuda_graph: bool | None = getattr(generation_config, "use_cuda_graph", None)
         self.profile = getattr(generation_config, "profile", False) # TODO: not supported yet
         self.manual_eviction = manual_eviction
         self.batch_processor: ContinuousBatchProcessor | None = None
-
         self._allow_prefix_sharing = allow_prefix_sharing

-        # If a number of cuda graphs was specified for either Q or KV, we activate cuda graphs
-        if num_q_cuda_graphs > 0 or num_kv_cuda_graphs > 0:
-            self.use_cuda_graph = True
-        # If use_cuda_graph is specified, we follow the user's choice
-        elif use_cuda_graph is not None:
-            self.use_cuda_graph = use_cuda_graph
-        # If the use of cuda graphs is not specified, we follow the user's choice, otherwise we have a default heuristic
-        else:
-            # Attention implementations where an attention mask is needed suffer a lot more from the padding associated
-            # with cuda graphs, so default is to turn cuda graphs off for those implementations
-            self.use_cuda_graph = not attn_mask_is_needed(self.model.config)
-            logger.warning(
-                f"No behavior specified for use_cuda_graph, defaulting to {self.use_cuda_graph = } because "
-                f"{self.model.config._attn_implementation = }. If you want to save memory, turn off cuda graphs, but "
-                "they can improve performances."
-            )
+        self.use_cuda_graph = self._decide_use_cuda_graphs(
+            use_cuda_graph=getattr(generation_config, "use_cuda_graph", None),
+            num_q_padding_intervals=num_q_padding_intervals,
+            num_kv_padding_intervals=num_kv_padding_intervals,
+            compile_config=getattr(generation_config, "compile_config", None),
+        )

-        # If cuda graphs are activated, we set the number of cuda graphs for Q and KV if not specified
-        if self.use_cuda_graph:
-            self.num_q_cuda_graphs = num_q_cuda_graphs if num_q_cuda_graphs > 0 else NUM_Q_CUDA_GRAPHS
-            self.num_kv_cuda_graphs = num_kv_cuda_graphs if num_kv_cuda_graphs > 0 else NUM_KV_CUDA_GRAPHS
+        # We set the number of padding intervals for Q and KV
+        self.q_padding_intervals = num_q_padding_intervals if num_q_padding_intervals > 0 else NUM_Q_PADDING_INTERVALS
+        self.kv_padding_intervals = (
+            num_kv_padding_intervals if num_kv_padding_intervals > 0 else NUM_KV_PADDING_INTERVALS
+        )

         if self.log_prob_generation:
             raise NotImplementedError("log_prob_generation is not supported yet")

+    def _decide_use_cuda_graphs(
+        self,
+        use_cuda_graph: bool | None,
+        num_q_padding_intervals: int,
+        num_kv_padding_intervals: int,
+        compile_config: CompileConfig | None,
+    ) -> bool:
+        """Returns whether or not to use cuda graphs for continuous batching, depending on the following criteria:
+        - (use_cuda_graph) which is the user choice
+        - (num_q_padding_intervals) or (num_kv_padding_intervals) which is used to pad inputs: if it was specified by
+          the user, it's probable they want to use cuda graphs so inputs need to be padded
+        - (compile_config): if compile is on, turn on cuda graphs unless the compile mode uses its own cudagraphs
+        If none of the above criteria are met, we use a default heuristic based on the attention implementation: we turn
+        on cuda graphs if and only if no attention mask is needed.
+        """
+        # If use_cuda_graph is specified, we follow the user's choice
+        if use_cuda_graph is not None:
+            return use_cuda_graph
+        # If a number of padding intervals was specified for either Q or KV, we activate cuda graphs
+        if num_q_padding_intervals > 0 or num_kv_padding_intervals > 0:
+            return True
+        # If a compile config was found, turn off cuda graphs if the compile config already uses them
+        if compile_config is not None:
+            options = torch._inductor.list_mode_options().get(compile_config.mode, compile_config.options)
+            compile_uses_cudagraphs = options.get("triton.cudagraphs", False)
+            if compile_uses_cudagraphs:
+                logger.warning(
+                    f"Compile config {compile_config.mode = } uses cudagraphs, which usually does not work well with "
+                    "continuous batching. We recommend using mode 'default' or 'max-autotune-no-cudagraphs' instead."
+                )
+            return not compile_uses_cudagraphs # TODO: should this also match the dynamic shapes?
+        # Otherwise we have a default heuristic based on the attention implementation:
+        # attention implementations where an attention mask is needed suffer a lot more from the padding associated
+        # with cuda graphs, so default is to turn cuda graphs off for those implementations
+        use_cuda_graph = not attn_mask_is_needed(self.model.config)
+        logger.warning(
+            f"No behavior specified for use_cuda_graph, defaulting to {use_cuda_graph = } because "
+            f"{self.model.config._attn_implementation = }. If you want to save memory, turn off cuda graphs, but "
+            "they can improve performances."
+        )
+        return use_cuda_graph
+
     @traced
     def start(self) -> None:
         """Start the background generation thread."""
@@ -822,7 +870,7 @@ class ContinuousBatchingManager:
             logger.warning("\nBatch processor was not initialized.")
         else:
             if self.batch_processor.cache.use_prefix_sharing:
-                logger.warning(
+                logger.info(
                     f"\nPrefix sharing was on. Total prefix length: {self.batch_processor.cache._total_prefix_length}"
                 )

@@ -999,6 +1047,8 @@ class ContinuousBatchingManager:
             scheduler=scheduler(paged_attention_cache, self.manual_eviction),
             manual_eviction=self.manual_eviction,
             use_cuda_graph=self.use_cuda_graph,
+            q_padding_intervals=self.q_padding_intervals,
+            kv_padding_intervals=self.kv_padding_intervals,
         )
         self.batch_processor = batch_processor
         self.current_batch = 0
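
The hunks above thread `compile_config` and the padding intervals from the generation config down to the batch processor. A hedged usage sketch follows; the checkpoint and prompts are illustrative assumptions, and the `generate_batch` entry point name is taken from the library's continuous batching API rather than shown verbatim in these hunks:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig, GenerationConfig

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")

generation_config = GenerationConfig(
    max_new_tokens=64,
    # "max-autotune-no-cudagraphs" avoids inductor-managed cudagraphs, which the manager above
    # warns do not combine well with continuous batching
    compile_config=CompileConfig(mode="max-autotune-no-cudagraphs"),
)

# Continuous batching takes pre-tokenized prompts (lists of token ids)
prompts = [tokenizer(text).input_ids for text in ["Hello, world!", "Continuous batching?"]]
outputs = model.generate_batch(prompts, generation_config=generation_config)
```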
@@ -1024,12 +1074,15 @@ class ContinuousBatchingManager:
         # Debug logging of the current memory usage
         if logger.level <= logging.DEBUG:
             device, total, reserved, allocated = get_device_and_memory_breakdown()
-            logger.debug(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}")
+            available_memory = total - max(allocated, reserved)
+            logger.debug(
+                f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}, Available: {available_memory}"
+            )

         self._generation_step()

         if torch.cuda.is_available():
-            torch.cuda.synchronize()
+            torch.cuda.synchronize() # FIXME: why is this needed?
         # Processor updates the batch after generation step is truly over
         batch_processor.update_batch()

@@ -1099,18 +1152,19 @@ class ContinuousMixin:
         generation_config: GenerationConfig | None = None,
         manual_eviction: bool = False,
         max_queue_size: int = 0,
-        num_q_cuda_graphs: int = 0,
-        num_kv_cuda_graphs: int = 0,
+        num_q_padding_intervals: int = 0,
+        num_kv_padding_intervals: int = 0,
         allow_prefix_sharing: bool = True,
     ) -> ContinuousBatchingManager:
         """Initialize a manager for continuous batching inference.

         Args:
-            generation_config: Custom generation configuration
+            generation_config: An optional generation configuration, which may contain a CompileConfig object
             manual_eviction: Whether to manually evict requests from the cache
             max_queue_size: Maximum size of the input request queue
-            num_q_cuda_graphs: Number of CUDA graphs to use for the query dimension
-            num_kv_cuda_graphs: Number of CUDA graphs to use for the keys/values dimension
+            num_q_padding_intervals: Number of intervals used to pad the query dimension
+            num_kv_padding_intervals: Number of intervals used to pad the keys/values dimension
+            allow_prefix_sharing: A flag to allow prefix sharing if the model has only full attention layers

         Returns:
             `ContinuousBatchingManager`: The manager instance to add requests and retrieve results.
@@ -1132,8 +1186,8 @@ class ContinuousMixin:
             generation_config=gen_config,
             manual_eviction=manual_eviction,
             max_queue_size=max_queue_size,
-            num_q_cuda_graphs=num_q_cuda_graphs,
-            num_kv_cuda_graphs=num_kv_cuda_graphs,
+            num_q_padding_intervals=num_q_padding_intervals,
+            num_kv_padding_intervals=num_kv_padding_intervals,
             allow_prefix_sharing=allow_prefix_sharing,
         )

@@ -1144,11 +1198,11 @@
         self,
         inputs: list[list[int]],
         generation_config: GenerationConfig | None = None,
-        progress_bar: bool = True,
-        num_q_cuda_graphs: int = 0,
-        num_kv_cuda_graphs: int = 0,
+        num_q_padding_intervals: int = 0,
+        num_kv_padding_intervals: int = 0,
         allow_prefix_sharing: bool = True,
         record_timestamps: bool = False,
+        progress_bar: bool = True,
         **kwargs,
     ) -> dict[str, GenerationOutput]:
         """Generate sequences for a batch of prompts using continuous batching.
@@ -1156,14 +1210,15 @@ class ContinuousMixin:
         Args:
             inputs: List of input token sequences (prompts)
             generation_config: Optional generation configuration
-            num_q_cuda_graphs: Number of CUDA graphs to use for the query dimension
-            num_kv_cuda_graphs: Number of CUDA graphs to use for the keys/values dimension
+            num_q_padding_intervals: Number of intervals used to pad the query dimension
+            num_kv_padding_intervals: Number of intervals used to pad the keys/values dimension
+            allow_prefix_sharing: A flag to allow prefix sharing if the model has only full attention layers
+            record_timestamps: If set to true, the requests will have a timestamp for each token generated
+            progress_bar: If set to true, a progress bar will be displayed
             **kwargs: Additional generation parameters

         Returns:
-            `list[list[int]]`: A list containing the generated sequences (including prompt tokens
-            if not handled otherwise) for each input prompt, in the same order.
-            Returns an empty list `[]` for requests that failed.
+            `dict[str, GenerationOutput]`: a dictionary of request ids to GenerationOutput objects
         """
         if not inputs:
             return {}
@@ -1177,8 +1232,8 @@ class ContinuousMixin:
         with (
             self.continuous_batching_context_manager(
                 generation_config=generation_config,
-                num_q_cuda_graphs=num_q_cuda_graphs,
-                num_kv_cuda_graphs=num_kv_cuda_graphs,
+                num_q_cuda_graphs=num_q_padding_intervals,
+                num_kv_cuda_graphs=num_kv_padding_intervals,
                 allow_prefix_sharing=allow_prefix_sharing,
                 block=True,
                 timeout=5,
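
Continuing the earlier sketch (same hypothetical model, prompts, and generation config), the renamed knobs are passed straight to `generate_batch`; per the decision logic shown earlier, setting either interval count also switches CUDA graphs on:

```python
outputs = model.generate_batch(
    inputs=prompts,
    generation_config=generation_config,
    num_q_padding_intervals=4,   # formerly num_q_cuda_graphs: padding granularity for query tokens
    num_kv_padding_intervals=8,  # formerly num_kv_cuda_graphs: padding granularity for key/value tokens
    allow_prefix_sharing=True,
    progress_bar=True,
)
# The return type is now a dict of request ids to GenerationOutput objects
for request_id, result in outputs.items():
    print(request_id, result)
```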
@@ -18,7 +18,7 @@ import os
 from typing import Any, Optional, TypeVar, Union

 import numpy as np
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, is_offline_mode

 from .dynamic_module_utils import custom_object_save
 from .feature_extraction_utils import BatchFeature as BaseBatchFeature
@@ -28,7 +28,6 @@ from .utils import (
28
28
  PROCESSOR_NAME,
29
29
  PushToHubMixin,
30
30
  copy_func,
31
- is_offline_mode,
32
31
  logging,
33
32
  safe_load_json_file,
34
33
  )
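The import shuffle above (likely `transformers/image_processing_base.py`, given the `BatchFeature` import) moves `is_offline_mode` from `transformers.utils` to `huggingface_hub`. The sketch below shows how downstream code would adapt; treating the helper as a zero-argument boolean check is an assumption based on typical usage, not something spelled out in the diff.

from huggingface_hub import create_repo, is_offline_mode

def maybe_create_repo(repo_id: str):
    # Skip Hub calls entirely when offline mode is enabled (e.g. HF_HUB_OFFLINE=1).
    if is_offline_mode():
        return None
    return create_repo(repo_id, exist_ok=True)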
@@ -19,7 +19,6 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_availa
 _import_structure = {
     "aqlm": ["replace_with_aqlm_linear"],
     "awq": [
-        "fuse_awq_modules",
         "post_init_awq_exllama_modules",
         "post_init_awq_ipex_modules",
         "replace_quantization_scales",
@@ -54,6 +53,7 @@ _import_structure = {
     "finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
     "fsdp": ["is_fsdp_enabled", "is_fsdp_managed_module"],
     "ggml": [
+        "GGUF_CONFIG_DEFAULTS_MAPPING",
         "GGUF_CONFIG_MAPPING",
         "GGUF_TOKENIZER_MAPPING",
         "_gguf_parse_value",
@@ -73,6 +73,7 @@ _import_structure = {
         "replace_kernel_forward_from_hub",
         "use_kernel_forward_from_hub",
         "use_kernel_func_from_hub",
+        "use_kernelized_func",
     ],
     "integration_utils": [
         "INTEGRATION_TO_CALLBACK",
@@ -165,7 +166,6 @@ else:
 if TYPE_CHECKING:
     from .aqlm import replace_with_aqlm_linear
     from .awq import (
-        fuse_awq_modules,
         post_init_awq_exllama_modules,
         post_init_awq_ipex_modules,
         replace_quantization_scales,
@@ -200,6 +200,7 @@ if TYPE_CHECKING:
     from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
     from .fsdp import is_fsdp_enabled, is_fsdp_managed_module
     from .ggml import (
+        GGUF_CONFIG_DEFAULTS_MAPPING,
         GGUF_CONFIG_MAPPING,
         GGUF_TOKENIZER_MAPPING,
         _gguf_parse_value,
@@ -214,6 +215,7 @@ if TYPE_CHECKING:
         replace_kernel_forward_from_hub,
         use_kernel_forward_from_hub,
         use_kernel_func_from_hub,
+        use_kernelized_func,
     )
     from .integration_utils import (
         INTEGRATION_TO_CALLBACK,
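The `_import_structure` hunks above follow transformers' lazy-import convention: every public symbol (here the added `GGUF_CONFIG_DEFAULTS_MAPPING` and `use_kernelized_func`, and the dropped `fuse_awq_modules`) must appear both in the runtime `_import_structure` mapping and in the mirrored `TYPE_CHECKING` block. The sketch below shows that pattern with placeholder module and symbol names; only `_LazyModule` itself is a real transformers helper, and the snippet assumes it lives inside the package.

import sys
from typing import TYPE_CHECKING

from ..utils import _LazyModule  # assumes this module sits inside the transformers package

_import_structure = {
    "my_backend": ["replace_with_my_linear"],  # hypothetical integration module and symbol
}

if TYPE_CHECKING:
    # Mirrored eager import so static type checkers and IDEs can resolve the symbol.
    from .my_backend import replace_with_my_linear
else:
    # At runtime the submodule is only imported on first attribute access.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)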
@@ -392,6 +392,15 @@ def _get_device_map(
         )
     else:
         inferred_max_memory = get_max_memory(max_memory)
+
+    # If the user does not provide `max_memory`, accelerate sets the WHOLE cpu available memory as available.
+    # This is unwanted, as we don't want to set extremely tight bound and pressure for cpu if we are memory-constrained,
+    # especially if the model uses WeightConverter (because there will be some uncontrollable cpu memory spikes during
+    # the conversions before we resave the weights). In those cases, it's better to offload to disk a bit more
+    # if we were in-between, as otherwise we blow-up cpu memory
+    if max_memory is None:
+        inferred_max_memory["cpu"] *= 0.90
+
     if hf_quantizer is not None:
         inferred_max_memory = hf_quantizer.adjust_max_memory(inferred_max_memory)
 
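The new CPU cap above only triggers when the caller passed no `max_memory`. Below is a standalone sketch of the same logic, assuming (as the diff does) that accelerate's `get_max_memory` returns a `{device: budget}` mapping with a `"cpu"` entry.

from accelerate.utils import get_max_memory

def inferred_max_memory_sketch(max_memory=None):
    inferred = get_max_memory(max_memory)
    # With no user-provided budget, keep ~10% of host RAM free so that weight-conversion
    # spikes push borderline modules to disk offload instead of exhausting CPU memory.
    if max_memory is None:
        inferred["cpu"] *= 0.90
    return inferred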
@@ -466,10 +475,10 @@ def expand_device_map(device_map, param_names):
 
 
 def accelerate_disk_offload(
+    model: "PreTrainedModel",
     disk_offload_folder: str | None,
     checkpoint_files: list[str] | None,
     device_map: dict,
-    expected_keys: list[str],
     sharded_metadata: dict | None,
     dtype: torch.dtype | None,
     weight_mapping=None,
@@ -493,7 +502,8 @@ def accelerate_disk_offload(
     # In this case, the offload index is simply the existing safetensors (except if using custom weight loading
     # Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
     if is_offloaded_safetensors:
-        param_device_map = expand_device_map(device_map, expected_keys)
+        meta_state_dict = model.state_dict()
+        param_device_map = expand_device_map(device_map, meta_state_dict.keys())
         str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
         if sharded_metadata is None:
             weight_map = dict.fromkeys(safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0])
@@ -502,7 +512,9 @@
             weight_map = {k: os.path.join(folder, v) for k, v in sharded_metadata["weight_map"].items()}
 
         # Update the weight names according to the `weight_mapping`
-        weight_renaming_map = {rename_source_key(k, renamings, [])[0]: k for k in weight_map}
+        weight_renaming_map = {
+            rename_source_key(k, renamings, [], model.base_model_prefix, meta_state_dict)[0]: k for k in weight_map
+        }
 
         # Prepare the index using existing safetensors files
         disk_offload_index = {
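Since `accelerate_disk_offload` now derives the expected parameter names from the model's meta `state_dict()` instead of a separate `expected_keys` argument, the device map is expanded over those keys. The helper below is an illustrative re-implementation of what `expand_device_map` conceptually does (longest-prefix matching of parameter names against `device_map` keys); it is not the library's code.

def expand_device_map_sketch(device_map: dict, param_names) -> dict:
    expanded = {}
    for name in param_names:
        # Keep the most specific matching module prefix; "" acts as the catch-all entry.
        matches = [p for p in device_map if p == "" or name == p or name.startswith(p + ".")]
        if matches:
            expanded[name] = device_map[max(matches, key=len)]
    return expanded

# expand_device_map_sketch({"": "cpu", "lm_head": "disk"},
#                          ["model.layers.0.attn.q_proj.weight", "lm_head.weight"])
# -> {"model.layers.0.attn.q_proj.weight": "cpu", "lm_head.weight": "disk"}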
@@ -13,88 +13,60 @@
 # limitations under the License.
 "AQLM (Additive Quantization of Language Model) integration file"
 
-from ..utils import ACCELERATE_MIN_VERSION, is_accelerate_available, is_aqlm_available, is_torch_available
+from ..quantizers.quantizers_utils import should_convert_module
+from ..utils import is_accelerate_available, is_torch_available, logging
 
 
+if is_accelerate_available():
+    from accelerate import init_empty_weights
+
 if is_torch_available():
     import torch.nn as nn
 
+logger = logging.get_logger(__name__)
 
 
-def replace_with_aqlm_linear(
-    model,
-    quantization_config=None,
-    linear_weights_not_to_quantize=None,
-    current_key_name=None,
-    has_been_replaced=False,
-):
+
+def replace_with_aqlm_linear(model, modules_to_not_convert: list[str] | None = None, quantization_config=None):
     """
     Public method that recursively replaces the Linear layers of the given model with AQLM quantized layers.
-    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
-    conversion has been successful or not.
 
     Args:
         model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
-        quantization_config (`AqlmConfig`):
-            The quantization config object that contains the quantization parameters.
-        linear_weights_not_to_quantize (`list[str]`, *optional*):
+        modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
            A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
            converted.
-        current_key_name (`list`, *optional*):
-            A list that contains the current key name. This is used for recursion and should not be passed by the user.
-        has_been_replaced (`bool`, *optional*):
-            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
-            should not be passed by the user.
+        quantization_config (`AqlmConfig`):
+            The quantization config object that contains the quantization parameters.
     """
-    if not is_aqlm_available():
-        raise ValueError("AQLM is not available. Please install it with `pip install aqlm[cpu,gpu]`")
-
-    if not is_accelerate_available():
-        raise ValueError(
-            f"AQLM requires Accelerate to be installed: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
-        )
-
-    if linear_weights_not_to_quantize is None:
-        linear_weights_not_to_quantize = []
-
-    from accelerate import init_empty_weights
     from aqlm import QuantizedLinear
 
-    for name, module in model.named_children():
-        if current_key_name is None:
-            current_key_name = []
-        current_key_name.append(name)
-
-        if isinstance(module, nn.Linear):
-            # Check if the current key is not in the `linear_weights_not_to_quantize`
-            if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize:
-                with init_empty_weights():
-                    in_features = module.in_features
-                    out_features = module.out_features
+    has_been_replaced = False
+    # we need this to correctly materialize the weights during quantization
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+        with init_empty_weights():
+            if isinstance(module, nn.Linear):
+                new_module = QuantizedLinear(
+                    module.in_features,
+                    module.out_features,
+                    bias=module.bias is not None,
+                    in_group_size=quantization_config.in_group_size,
+                    out_group_size=quantization_config.out_group_size,
+                    num_codebooks=quantization_config.num_codebooks,
+                    nbits_per_codebook=quantization_config.nbits_per_codebook,
+                )
+                new_module.source_cls = type(module)
+                new_module.requires_grad_(False)
+                model.set_submodule(module_name, new_module)
+                has_been_replaced = True
 
-                    model._modules[name] = QuantizedLinear(
-                        in_features,
-                        out_features,
-                        bias=module.bias is not None,
-                        in_group_size=quantization_config.in_group_size,
-                        out_group_size=quantization_config.out_group_size,
-                        num_codebooks=quantization_config.num_codebooks,
-                        nbits_per_codebook=quantization_config.nbits_per_codebook,
-                    )
-                    has_been_replaced = True
+    if not has_been_replaced:
+        logger.warning(
+            "You are loading your model using aqlm but no linear modules were found in your model."
+            " Please double check your model architecture, or submit an issue on github if you think this is"
+            " a bug."
+        )
 
-            # Store the module class in case we need to transpose the weight later
-            model._modules[name].source_cls = type(module)
-            # Force requires grad to False to avoid unexpected errors
-            model._modules[name].requires_grad_(False)
-        if len(list(module.children())) > 0:
-            _, has_been_replaced = replace_with_aqlm_linear(
-                module,
-                quantization_config=quantization_config,
-                linear_weights_not_to_quantize=linear_weights_not_to_quantize,
-                current_key_name=current_key_name,
-                has_been_replaced=has_been_replaced,
-            )
-        # Remove the last key for recursion
-        current_key_name.pop(-1)
-    return model, has_been_replaced
+    return model
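To close out the AQLM rewrite, here is a hedged usage sketch of the new, flatter signature. In practice this helper is invoked by the AQLM quantizer during `from_pretrained`, so calling it by hand is purely illustrative; it requires the `aqlm` package, the `AqlmConfig` values mirror its documented defaults, and `"gpt2"` is a placeholder architecture. Because the replacement happens under `init_empty_weights`, the new `QuantizedLinear` modules still need real AQLM weights loaded afterwards.

from transformers import AqlmConfig, AutoModelForCausalLM
from transformers.integrations import replace_with_aqlm_linear

quantization_config = AqlmConfig(in_group_size=8, out_group_size=1, num_codebooks=1, nbits_per_codebook=16)
model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model

# Skip the output head, as quantizers commonly do for tied or otherwise sensitive modules.
model = replace_with_aqlm_linear(
    model,
    modules_to_not_convert=["lm_head"],
    quantization_config=quantization_config,
)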