transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/generation/continuous_batching/continuous_api.py

@@ -29,7 +29,7 @@ from tqdm import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm
 
 from ...configuration_utils import PretrainedConfig
-from ...generation.configuration_utils import GenerationConfig
+from ...generation.configuration_utils import CompileConfig, GenerationConfig
 from ...generation.logits_process import LogitsProcessor
 from ...utils.logging import logging
 from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
@@ -45,17 +45,20 @@ generation goes on, there are two dimensions that change:
 - the number of keys/values tokens (KV), which grows as the cache does
 
 To solve this, we slice along those dimensions to fixed lengths. The size of the slices is controlled by the variables
-below: NUM_X_CUDA_GRAPHS means that we create at most NUM_X_CUDA_GRAPHS graphs for the X dimension. So if the maximum
-number of queries tokens is 1000, and NUM_Q_CUDA_GRAPHS is 4, we will slice the number of queries token by intervals of
-1000 / 4 = 250 tokens, ie. to 250, 500, 750 or 1000 queries tokens.
+num_x_padding_intervals: NUM_X_PADDING_INTERVALS means that we create at most NUM_X_PADDING_INTERVALS graphs for the X
+dimension. So if the maximum number of queries tokens is 1000, and NUM_Q_PADDING_INTERVALS is 4, we will slice the
+number of queries token by intervals of 1000 / 4 = 250 tokens, ie. to 250, 500, 750 or 1000 queries tokens.
 
 Smaller slices means more granularity and thus less padding. But since each graph takes up space on the GPU and time to
 create, we don't want to many graphs. And since the size of the KV dimension is the number of queries tokens plus the
 number of tokens cached, dimension of KV is usually much larger than the dimension of Q. So we have more granularity
 for the KV dimension than the query dimension.
+
+This variable used to be called NUM_X_CUDA_GRAPHS, but we renamed it to NUM_X_PADDING_INTERVALS because it is used for
+padding in the case of cuda graphs AND torch.compile.
 """
-NUM_Q_CUDA_GRAPHS = 4
-NUM_KV_CUDA_GRAPHS = 8
+NUM_Q_PADDING_INTERVALS = 4
+NUM_KV_PADDING_INTERVALS = 8
 
 
 def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
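
The interval arithmetic described in that docstring is easy to verify by hand. Below is a minimal standalone sketch (not part of the package) that mirrors pad_by_intervals as patched in the next hunk, including its new size > 0 guard; the sample sizes are made up for illustration:

    from math import ceil

    NUM_Q_PADDING_INTERVALS = 4  # at most 4 padded sizes for the query dimension

    def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
        # Round `size` up to the next multiple of max_value // nb_intervals,
        # capped at max_value; a size of 0 is still padded up to one interval.
        interval_size = max_value // nb_intervals
        if interval_size == 0:
            return max_value
        padded = ceil(size / interval_size) * interval_size if size > 0 else interval_size
        return min(padded, max_value)

    # With a maximum of 1000 query tokens and 4 intervals, every batch is
    # padded to one of 250, 500, 750 or 1000 tokens:
    for size in (0, 1, 250, 251, 999):
        print(size, "->", pad_by_intervals(size, 1000, NUM_Q_PADDING_INTERVALS))
    # 0 -> 250, 1 -> 250, 250 -> 250, 251 -> 500, 999 -> 1000
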
@@ -63,7 +66,7 @@ def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
63
66
  interval_size = max_value // nb_intervals
64
67
  if interval_size == 0:
65
68
  return max_value
66
- padded = ceil(size / interval_size) * interval_size
69
+ padded = ceil(size / interval_size) * interval_size if size > 0 else interval_size
67
70
  return min(padded, max_value)
68
71
 
69
72
 
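
For readers skimming the diff, the padding helper above is small enough to run standalone. The sketch below reproduces the new body of pad_by_intervals; the comment and the assertions are illustrative additions, not library code.

from math import ceil

def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
    # Round `size` up to the next multiple of (max_value // nb_intervals), capped at max_value
    interval_size = max_value // nb_intervals
    if interval_size == 0:
        return max_value
    padded = ceil(size / interval_size) * interval_size if size > 0 else interval_size
    return min(padded, max_value)

# With NUM_Q_PADDING_INTERVALS = 4 and a maximum of 1000 query tokens, every batch is
# padded to one of only four sizes, so at most four graphs / compiled shapes are created.
assert pad_by_intervals(1, 1000, 4) == 250
assert pad_by_intervals(251, 1000, 4) == 500
assert pad_by_intervals(999, 1000, 4) == 1000
assert pad_by_intervals(0, 1000, 4) == 250  # the rc2 change: an empty size maps to the smallest interval, not 0
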
@@ -188,6 +191,8 @@ class ContinuousBatchProcessor:
188
191
  scheduler: Scheduler,
189
192
  manual_eviction: bool,
190
193
  use_cuda_graph: bool,
194
+ q_padding_intervals: int,
195
+ kv_padding_intervals: int,
191
196
  ) -> None:
192
197
  """Initialize the continuous batch processor.
193
198
 
@@ -221,7 +226,14 @@ class ContinuousBatchProcessor:
221
226
  # Accumulator for batch scheduling
222
227
  self.requests_in_batch: list[RequestState] = []
223
228
  # Cuda graphs for the generation step
229
+ self.q_padding_intervals = q_padding_intervals
230
+ self.kv_padding_intervals = kv_padding_intervals
224
231
  self._graphs: dict[tuple[int, int], torch.cuda.CUDAGraph] | None = {} if use_cuda_graph else None
232
+ # Compile-related arguments
233
+ self.compile_config: CompileConfig | None = getattr(generation_config, "compile_config", None)
234
+ self._forward_process_and_sample_is_compiled = False
235
+
236
+ self._pad_inputs = use_cuda_graph or (self.compile_config is not None and not self.compile_config.dynamic)
225
237
 
226
238
  # Set up metrics collector
227
239
  self.max_batch_tokens = cache.max_batch_tokens
@@ -247,7 +259,7 @@ class ContinuousBatchProcessor:
247
259
  self.cumulative_seqlens_q = torch.empty((self.max_batch_tokens + 1,), **self.tensor_metadata)
248
260
  self.max_seqlen_q = 0
249
261
  self.logits_indices = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)
250
- self.output_ids = torch.empty((1, self.max_batch_tokens), **self.tensor_metadata)
262
+ self.output_ids = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)
251
263
 
252
264
  # For some kwargs, we have a dict of tensors with as many items as there are attention types
253
265
  layer_types = getattr(self.config, "layer_types", None)
@@ -299,7 +311,7 @@ class ContinuousBatchProcessor:
299
311
  self.cumulative_seqlens_q[: b_size + 1].zero_()
300
312
  self.max_seqlen_q = 0
301
313
  self.logits_indices[:q_len].fill_(-1)
302
- self.output_ids[:, :q_len].fill_(-1)
314
+ self.output_ids[:q_len].fill_(-1)
303
315
 
304
316
  # Reset the attributes that are either tensors or dict of tensors
305
317
  for layer_type in self.cumulative_seqlens_k:
@@ -435,7 +447,7 @@ class ContinuousBatchProcessor:
435
447
  self.metrics.record_batch_metrics(self.requests_in_batch)
436
448
 
437
449
  # Reset the static tensors used for storage
438
- self.reset_static_tensors() # TODO: this might be unnecessary
450
+ self.reset_static_tensors() # FIXME: why does this make the generation faster?
439
451
 
440
452
  # Prepare accumulators
441
453
  self.actual_query_length = 0
@@ -545,13 +557,10 @@ class ContinuousBatchProcessor:
545
557
  self.actual_index_sizes[i] = (len(group_read_indices), len(group_write_indices))
546
558
 
547
559
  @traced
548
- def _sync(self) -> list[int]:
549
- if self.output_ids is not None:
550
- try:
551
- return self.output_ids.tolist()[0]
552
- except Exception:
553
- return [0, 1]
554
- return [0, 0]
560
+ def _get_new_tokens(self, num_new_tokens: int) -> list[int]:
561
+ indices = self.logits_indices[:num_new_tokens]
562
+ new_tokens = self.output_ids[indices]
563
+ return new_tokens.tolist()
555
564
 
556
565
  @traced
557
566
  def _maybe_send_output(self, state: RequestState) -> None:
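
As a toy illustration of the new _get_new_tokens logic (values made up): output_ids is now a flat buffer with one sampled token per batch position, and logits_indices records which position belongs to each request, so a single indexed gather replaces the old full tolist() sync.

import torch

output_ids = torch.tensor([11, 22, 33, 44, 55])  # hypothetical sampled tokens, one per batch position
logits_indices = torch.tensor([2, 4, 0])         # hypothetical positions of 3 requests in the batch

num_new_tokens = 3
new_tokens = output_ids[logits_indices[:num_new_tokens]].tolist()
print(new_tokens)  # [33, 55, 11] -> one new token per request, in request order
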
@@ -562,29 +571,56 @@ class ContinuousBatchProcessor:
562
571
  @traced
563
572
  def update_batch(self) -> None:
564
573
  """Update request states based on generated tokens."""
565
- out_tokens = self._sync()
566
- for i, state in enumerate(self.requests_in_batch):
574
+ new_tokens = self._get_new_tokens(len(self.requests_in_batch))
575
+ current_logits_index = 0
576
+ for state in self.requests_in_batch:
567
577
  # If the request has no remaining prompt ids, it means prefill has already ended or just finished
568
578
  if len(state.remaining_prefill_tokens) == 0:
569
- self.metrics.record_ttft_metric(state.created_time, state.request_id)
570
- state.status = RequestStatus.DECODING
571
- token = out_tokens[self.logits_indices[i]]
579
+ # If there are no generated tokens yet, it means prefill just ended
580
+ if state.generated_len() == 0:
581
+ self.metrics.record_ttft_metric(state.created_time, state.request_id)
582
+ state.status = RequestStatus.DECODING
583
+
584
+ token = new_tokens[current_logits_index]
572
585
  state.tokens_to_process = [token]
586
+ current_logits_index += 1
587
+
573
588
  # Update the request and stop if it is complete
574
589
  is_finished = state.update_and_check_completion(token)
575
590
  # We mark the completed blocks as such
576
- self.cache.mark_blocks_as_complete(state)
591
+ self.cache.mark_shareable_blocks_as_complete(state)
577
592
  if is_finished:
578
593
  self.metrics.record_request_completion(state.created_time, state.request_id)
579
594
  self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
580
595
  self._maybe_send_output(state)
581
596
  # Otherwise, the request is still prefilling, but the prefill has been split
582
597
  elif state.status == RequestStatus.PREFILLING_SPLIT:
583
- self.cache.mark_blocks_as_complete(state)
598
+ self.cache.mark_shareable_blocks_as_complete(state)
584
599
  state.status = RequestStatus.SPLIT_PENDING_REMAINDER
585
600
  else:
586
601
  raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}")
587
602
 
603
+ # If some requests need to be forked, we do it now
604
+ copy_source, copy_destination = [], []
605
+ while self.scheduler._requests_to_fork:
606
+ # Get the number of children and reset it so it's not forked again
607
+ state = self.scheduler._requests_to_fork.pop()
608
+ num_children = state.num_children
609
+ state.num_children = 0
610
+ # Create the new requests and add them to the scheduler
611
+ new_request_ids = [f"{state.request_id}__child#{i}" for i in range(num_children)]
612
+ for new_request_id in new_request_ids:
613
+ self.scheduler.active_requests[new_request_id] = state.fork(new_request_id)
614
+ # Fork the cache
615
+ copy_src, copy_dst = self.cache.fork_request(state.request_id, new_request_ids)
616
+ copy_source.extend(copy_src)
617
+ copy_destination.extend(copy_dst)
618
+ # FIXME: if the fork can't be done, create a new pending request without forking instead of crashing everything
619
+
620
+ # The copy induced by the fork is done in one go (if it's even needed)
621
+ if copy_source:
622
+ self.cache.copy_cache(copy_source, copy_destination)
623
+
588
624
  if self.cache.get_num_free_blocks() == 0:
589
625
  raise ValueError("No more free blocks")
590
626
 
@@ -627,28 +663,39 @@ class ContinuousBatchProcessor:
627
663
  def _generation_step(self, model: nn.Module, logit_processor: LogitsProcessor, do_sample: bool) -> None:
628
664
  """Perform a single generation step."""
629
665
 
630
- # If cuda graphs are disabled, we just use the actual size
666
+ # If a compile config is specified, we compile the forward pass once in a wrapper
667
+ if self.compile_config is not None and not self._forward_process_and_sample_is_compiled:
668
+ self._forward_process_and_sample = torch.compile(
669
+ self._forward_process_and_sample,
670
+ fullgraph=self.compile_config.fullgraph,
671
+ mode=self.compile_config.mode,
672
+ dynamic=self.compile_config.dynamic,
673
+ backend=self.compile_config.backend,
674
+ options=self.compile_config.options,
675
+ )
676
+ self._forward_process_and_sample_is_compiled = True
677
+
678
+ # If inputs are statically sized, we find the padded sizes of the queries and keys/values
679
+ if self._pad_inputs:
680
+ padded_q = pad_by_intervals(self.actual_query_length, self.max_batch_tokens, self.q_padding_intervals)
681
+ max_read_index_size = max(self.actual_index_sizes[i][0] for i in range(self.cache.num_groups))
682
+ padded_read_index_size = pad_by_intervals(
683
+ max_read_index_size - self.max_batch_tokens,
684
+ self.cache.num_blocks * self.cache.block_size,
685
+ self.kv_padding_intervals,
686
+ )
687
+ else:
688
+ padded_q, padded_read_index_size = 0, 0
689
+ # Retrieve the model kwargs with or without padding
690
+ batch_data = self.get_model_kwargs(padded_q, padded_read_index_size)
691
+
692
+ # If we are not using cuda graphs, we perform the generation step and return
631
693
  if self._graphs is None:
632
- batch_data = self.get_model_kwargs()
633
694
  self._forward_process_and_sample(model, batch_data, logit_processor, do_sample)
634
695
  return None
635
696
 
636
- # Determine the padded size of the queries and keys/values
637
- padded_q = pad_by_intervals(self.actual_query_length, self.max_batch_tokens, NUM_Q_CUDA_GRAPHS)
638
-
639
- max_read_index_size = max(self.actual_index_sizes[i][0] for i in range(self.cache.num_groups))
640
- padded_read_index_size = pad_by_intervals(
641
- max_read_index_size - self.max_batch_tokens,
642
- self.cache.num_blocks * self.cache.block_size,
643
- NUM_KV_CUDA_GRAPHS,
644
- )
645
-
646
- # Get the batch data and the associated graph
647
- batch_data = self.get_model_kwargs(padded_q, padded_read_index_size)
648
-
649
- graph = self._graphs.get((padded_q, padded_read_index_size))
650
-
651
697
  # If we have a graph that fits, we replay it
698
+ graph = self._graphs.get((padded_q, padded_read_index_size))
652
699
  if graph is not None:
653
700
  graph.replay()
654
701
  return None
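
For context, the compile path added above follows a "wrap once, then reuse" pattern. The sketch below shows that pattern in isolation: StepRunner and _step are stand-ins, and the config object is only assumed to expose the same fields the diff reads from CompileConfig (fullgraph, mode, dynamic, backend).

import torch

class StepRunner:
    def __init__(self, compile_config=None):
        self.compile_config = compile_config  # e.g. a transformers CompileConfig, or None
        self._compiled = False

    def _step(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) * 2  # stand-in for forward pass + logits processing + sampling

    def run(self, x: torch.Tensor) -> torch.Tensor:
        # Compile lazily on the first call, then keep reusing the compiled callable
        if self.compile_config is not None and not self._compiled:
            self._step = torch.compile(
                self._step,
                fullgraph=self.compile_config.fullgraph,
                mode=self.compile_config.mode,
                dynamic=self.compile_config.dynamic,
                backend=self.compile_config.backend,
            )
            self._compiled = True
        return self._step(x)

runner = StepRunner()                    # no compile config: stays eager
print(runner.run(torch.randn(4)).shape)  # torch.Size([4])
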
@@ -673,7 +720,6 @@ class ContinuousBatchProcessor:
673
720
  ) -> None:
674
721
  """This function performs the forward pass, logits processing, and sampling; which are broken down into smaller
675
722
  function to be easier to trace with OpenTelemetry."""
676
- # with torch.no_grad():
677
723
  logits = self._model_forward(model, batch_data)
678
724
  # if self.log_prob_generation: batch_processor.output_probs.copy_(logits) # TODO
679
725
  probs = self._process_logit(batch_data, logits, logit_processor)
@@ -691,6 +737,7 @@ class ContinuousBatchProcessor:
691
737
  # Handle shape compatibility: logit processors expect 2D tensors [batch_size, vocab_size]
692
738
  # but continuous batching always produces 3D tensors [batch_size, seq_len, vocab_size]
693
739
  batch_size, seq_len, vocab_size = logits.shape
740
+ # NOTE: to exactly match generate, we should also convert logits_2d to float32 here, but it's not needed in practice
694
741
  logits_2d = logits.view(batch_size * seq_len, vocab_size)
695
742
  input_ids_2d = batch_data["input_ids"].view(batch_size * seq_len)
696
743
  # Process with 2D tensors
@@ -704,12 +751,11 @@ class ContinuousBatchProcessor:
704
751
  probs = nn.functional.softmax(probs, dim=-1)
705
752
  # probs[0] has shape [seq_len, vocab_size], multinomial returns [seq_len, 1]
706
753
  next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1) # Now [seq_len]
707
- # Add batch dimension back to match argmax output
708
- next_tokens = next_tokens.unsqueeze(0) # Now [1, seq_len]
709
754
  else:
710
- next_tokens = torch.argmax(probs, dim=-1) # Already [1, seq_len]
711
- tokens = next_tokens.size(1) # Get seq_len dimension
712
- self.output_ids[:, :tokens].copy_(next_tokens)
755
+ next_tokens = torch.argmax(probs, dim=-1) # shape is [1, seq_len]
756
+ next_tokens = next_tokens.squeeze(0) # shape is [seq_len]
757
+ tokens = next_tokens.size(0) # Get seq_len dimension
758
+ self.output_ids[:tokens].copy_(next_tokens)
713
759
 
714
760
 
715
761
  # Manager Class (User Interface)
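
Both branches above now produce a flat [seq_len] tensor before the copy into the 1-D output_ids buffer. The toy snippet below (made-up shapes and values) checks that the sampling and greedy paths agree on that shape.

import torch

torch.manual_seed(0)
seq_len, vocab_size = 4, 10
logits = torch.rand(1, seq_len, vocab_size)  # continuous batching always yields [1, seq_len, vocab_size]

# do_sample=True path: softmax, then one multinomial draw per position
sampled = torch.multinomial(torch.softmax(logits, dim=-1)[0], num_samples=1).squeeze(-1)
# do_sample=False path: greedy argmax, then drop the leading batch dimension
greedy = torch.argmax(logits, dim=-1).squeeze(0)

assert sampled.shape == greedy.shape == (seq_len,)
output_ids = torch.empty(16, dtype=torch.long)
output_ids[:seq_len].copy_(greedy)  # mirrors self.output_ids[:tokens].copy_(next_tokens)
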
@@ -727,9 +773,9 @@ class ContinuousBatchingManager:
727
773
  generation_config: GenerationConfig,
728
774
  manual_eviction: bool = False,
729
775
  max_queue_size: int = 0,
730
- num_q_cuda_graphs: int = 0,
731
- num_kv_cuda_graphs: int = 0,
732
- allow_prefix_sharing: bool = True,
776
+ num_q_padding_intervals: int = 0,
777
+ num_kv_padding_intervals: int = 0,
778
+ allow_block_sharing: bool = True,
733
779
  ) -> None:
734
780
  """Initialize the continuous batching manager.
735
781
 
@@ -737,65 +783,98 @@ class ContinuousBatchingManager:
737
783
  model: The language model for generation
738
784
  generation_config: Configuration for generation parameters
739
785
  max_queue_size: Maximum size of the request queue (0 = unlimited)
740
- num_q_cuda_graphs: (optional) Number of CUDA graphs to use for the query dimension
741
- num_kv_cuda_graphs: (optional) Number of CUDA graphs to use for the keys/values dimension
742
- allow_prefix_sharing: (optional) Whether to allow prefix sharing if the model has only full attention layers
786
+ num_q_padding_intervals: (optional) Number of intervals used to pad the query dimension
787
+ num_kv_padding_intervals: (optional) Number of intervals used to pad the keys/values dimension
788
+ allow_block_sharing: (optional) Whether to allow block sharing if the model has some full attention layers
743
789
  """
790
+ # Switch to the paged version of the attention implementation if necessary
744
791
  if "paged|" not in model.config._attn_implementation:
745
- attn_implementation = f"paged|{model.config._attn_implementation}"
746
- model.config._attn_implementation = attn_implementation
747
-
748
- # lazy loading flash attention including kernel variations
749
- if "flash" in attn_implementation:
750
- from ...modeling_flash_attention_utils import lazy_import_paged_flash_attention
751
-
752
- lazy_import_paged_flash_attention(attn_implementation)
792
+ model.set_attn_implementation(f"paged|{model.config._attn_implementation}")
753
793
 
794
+ # Internal arguments
754
795
  self.model = model.eval()
755
- generation_config = model.generation_config if generation_config is None else generation_config
756
- self.generation_config = generation_config
796
+ self.manual_eviction = manual_eviction
797
+ self._allow_block_sharing = allow_block_sharing
798
+ self._use_prefix_sharing = allow_block_sharing # approximation until the cache is created
799
+
757
800
  self.input_queue = queue.Queue(maxsize=max_queue_size)
758
801
  self.output_queue = queue.Queue()
759
802
  self.stop_event = threading.Event()
760
- self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
803
+ self.batch_processor: ContinuousBatchProcessor | None = None
761
804
  self._generation_thread = None
762
805
  self._request_counter = 0
763
806
  self._request_lock = threading.Lock()
764
- self.model.generation_config.top_p = None
807
+
808
+ # Generation config related arguments
809
+ generation_config = model.generation_config if generation_config is None else generation_config
810
+ self.generation_config = generation_config
811
+ self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
765
812
  self.do_sample = getattr(generation_config, "do_sample", True)
766
813
  self.logit_processor = self.model._get_logits_processor(generation_config)
767
- use_cuda_graph: bool | None = getattr(generation_config, "use_cuda_graph", None)
768
- self.profile = getattr(generation_config, "profile", False) # TODO: not supported yet
769
- self.manual_eviction = manual_eviction
770
- self.batch_processor: ContinuousBatchProcessor | None = None
814
+ self.num_return_sequences = getattr(generation_config, "num_return_sequences", 1)
771
815
 
772
- self._allow_prefix_sharing = allow_prefix_sharing
816
+ # self.model.generation_config.top_p = None  # NOTE: figure out why this was here
773
817
 
774
- # If a number of cuda graphs was specified for either Q or KV, we activate cuda graphs
775
- if num_q_cuda_graphs > 0 or num_kv_cuda_graphs > 0:
776
- self.use_cuda_graph = True
777
- # If use_cuda_graph is specified, we follow the user's choice
778
- elif use_cuda_graph is not None:
779
- self.use_cuda_graph = use_cuda_graph
780
- # If the use of cuda graphs is not specified, we follow the user's choice, otherwise we have a default heuristic
781
- else:
782
- # Attention implementations where an attention mask is needed suffer a lot more from the padding associated
783
- # with cuda graphs, so default is to turn cuda graphs off for those implementations
784
- self.use_cuda_graph = not attn_mask_is_needed(self.model.config)
785
- logger.warning(
786
- f"No behavior specified for use_cuda_graph, defaulting to {self.use_cuda_graph = } because "
787
- f"{self.model.config._attn_implementation = }. If you want to save memory, turn off cuda graphs, but "
788
- "they can improve performances."
789
- )
818
+ # Cuda graph behavior is determined below using either user-specified arguments or heuristics
819
+ self.use_cuda_graph = self._decide_use_cuda_graphs(
820
+ use_cuda_graph=getattr(generation_config, "use_cuda_graph", None),
821
+ num_q_padding_intervals=num_q_padding_intervals,
822
+ num_kv_padding_intervals=num_kv_padding_intervals,
823
+ compile_config=getattr(generation_config, "compile_config", None),
824
+ )
790
825
 
791
- # If cuda graphs are activated, we set the number of cuda graphs for Q and KV if not specified
792
- if self.use_cuda_graph:
793
- self.num_q_cuda_graphs = num_q_cuda_graphs if num_q_cuda_graphs > 0 else NUM_Q_CUDA_GRAPHS
794
- self.num_kv_cuda_graphs = num_kv_cuda_graphs if num_kv_cuda_graphs > 0 else NUM_KV_CUDA_GRAPHS
826
+ # We set the number of padding intervals for Q and KV
827
+ self.q_padding_intervals = num_q_padding_intervals if num_q_padding_intervals > 0 else NUM_Q_PADDING_INTERVALS
828
+ self.kv_padding_intervals = (
829
+ num_kv_padding_intervals if num_kv_padding_intervals > 0 else NUM_KV_PADDING_INTERVALS
830
+ )
795
831
 
832
+ # Log probability generation is not supported yet (TODO)
796
833
  if self.log_prob_generation:
797
834
  raise NotImplementedError("log_prob_generation is not supported yet")
798
835
 
836
+ def _decide_use_cuda_graphs(
837
+ self,
838
+ use_cuda_graph: bool | None,
839
+ num_q_padding_intervals: int,
840
+ num_kv_padding_intervals: int,
841
+ compile_config: CompileConfig | None,
842
+ ) -> bool:
843
+ """Returns whether or not to use cuda graphs for continuous batching, depending on the following criteria:
844
+ - (use_cuda_graph) which is the user choice
845
+ - (num_q_padding_intervals) or (num_kv_padding_intervals), which are used to pad inputs: if either was specified by
846
+ the user, they most likely want to use cuda graphs, so inputs need to be padded
847
+ - (compile_config): if compile is on, turn on cuda graphs unless the compile mode uses its own cudagraphs
848
+ If none of the above criteria are met, we use a default heuristic based on the attention implementation: we turn
849
+ on cuda graphs if and only if no attention mask is needed.
850
+ """
851
+ # If use_cuda_graph is specified, we follow the user's choice
852
+ if use_cuda_graph is not None:
853
+ return use_cuda_graph
854
+ # If a number of padding intervals was specified for either Q or KV, we activate cuda graphs
855
+ if num_q_padding_intervals > 0 or num_kv_padding_intervals > 0:
856
+ return True
857
+ # If a compile config was found, turn off cuda graphs if the compile config already uses them
858
+ if compile_config is not None:
859
+ options = torch._inductor.list_mode_options().get(compile_config.mode, compile_config.options)
860
+ compile_uses_cudagraphs = options.get("triton.cudagraphs", False)
861
+ if compile_uses_cudagraphs:
862
+ logger.warning(
863
+ f"Compile config {compile_config.mode = } uses cudagraphs, which usually does not work well with "
864
+ "continuous batching. We recommend using mode 'default' or 'max-autotune-no-cudagraphs' instead."
865
+ )
866
+ return not compile_uses_cudagraphs # TODO: should this also match the dynamic shapes?
867
+ # Otherwise we have a default heuristic based on the attention implementation:
868
+ # attention implementations where an attention mask is needed suffer a lot more from the padding associated
869
+ # with cuda graphs, so default is to turn cuda graphs off for those implementations
870
+ use_cuda_graph = not attn_mask_is_needed(self.model.config)
871
+ logger.warning(
872
+ f"No behavior specified for use_cuda_graph, defaulting to {use_cuda_graph = } because "
873
+ f"{self.model.config._attn_implementation = }. If you want to save memory, turn off cuda graphs, but "
874
+ "they can improve performances."
875
+ )
876
+ return use_cuda_graph
877
+
799
878
  @traced
800
879
  def start(self) -> None:
801
880
  """Start the background generation thread."""
@@ -822,7 +901,7 @@ class ContinuousBatchingManager:
822
901
  logger.warning("\nBatch processor was not initialized.")
823
902
  else:
824
903
  if self.batch_processor.cache.use_prefix_sharing:
825
- logger.warning(
904
+ logger.info(
826
905
  f"\nPrefix sharing was on. Total prefix length: {self.batch_processor.cache._total_prefix_length}"
827
906
  )
828
907
 
@@ -884,6 +963,7 @@ class ContinuousBatchingManager:
884
963
  state = RequestState(
885
964
  request_id=request_id,
886
965
  initial_tokens=list(input_ids),
966
+ num_children=self.num_return_sequences - 1,
887
967
  record_timestamps=record_timestamps,
888
968
  tokens_to_process=list(input_ids),
889
969
  max_new_tokens=max_new_tokens,
@@ -902,6 +982,10 @@ class ContinuousBatchingManager:
902
982
  streaming: bool = False,
903
983
  record_timestamps: bool = False,
904
984
  ) -> None:
985
+ # If there is prefix sharing, we sort the inputs to maximize cache hits
986
+ if self._use_prefix_sharing:
987
+ inputs = sorted(inputs, reverse=True)
988
+ # Add requests in order
905
989
  for input_ids in inputs:
906
990
  self.add_request(
907
991
  input_ids, max_new_tokens=max_new_tokens, streaming=streaming, record_timestamps=record_timestamps
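
To see why the reverse lexicographic sort above helps prefix sharing, here is a toy example with made-up token ids: prompts that share a prefix end up next to each other, so consecutive requests can hit the same cached blocks.

prompts = [
    [1, 2, 3, 9],
    [7, 7, 1],
    [1, 2, 3, 4, 5],
    [1, 2, 8],
]
for prompt in sorted(prompts, reverse=True):
    print(prompt)
# [7, 7, 1]
# [1, 2, 8]
# [1, 2, 3, 9]
# [1, 2, 3, 4, 5]  <- the two prompts sharing the prefix [1, 2, 3] are now adjacent
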
@@ -972,8 +1056,9 @@ class ContinuousBatchingManager:
972
1056
  self.model.device,
973
1057
  self.model.dtype,
974
1058
  tp_size=getattr(self.model, "_tp_size", None), # Use model's actual TP setting
975
- allow_prefix_sharing=self._allow_prefix_sharing,
1059
+ allow_block_sharing=self._allow_block_sharing,
976
1060
  )
1061
+ self._use_prefix_sharing = paged_attention_cache.use_prefix_sharing # update the approximation
977
1062
  logger.debug(f"PagedAttentionCache created in {perf_counter() - t0} seconds")
978
1063
 
979
1064
  scheduler = None
@@ -999,6 +1084,8 @@ class ContinuousBatchingManager:
999
1084
  scheduler=scheduler(paged_attention_cache, self.manual_eviction),
1000
1085
  manual_eviction=self.manual_eviction,
1001
1086
  use_cuda_graph=self.use_cuda_graph,
1087
+ q_padding_intervals=self.q_padding_intervals,
1088
+ kv_padding_intervals=self.kv_padding_intervals,
1002
1089
  )
1003
1090
  self.batch_processor = batch_processor
1004
1091
  self.current_batch = 0
@@ -1024,13 +1111,12 @@ class ContinuousBatchingManager:
1024
1111
  # Debug logging of the current memory usage
1025
1112
  if logger.level <= logging.DEBUG:
1026
1113
  device, total, reserved, allocated = get_device_and_memory_breakdown()
1027
- logger.debug(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}")
1114
+ available_memory = total - max(allocated, reserved)
1115
+ logger.debug(
1116
+ f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}, Available: {available_memory}"
1117
+ )
1028
1118
 
1029
1119
  self._generation_step()
1030
-
1031
- if torch.cuda.is_available():
1032
- torch.cuda.synchronize()
1033
- # Processor updates the batch after generation step is truly over
1034
1120
  batch_processor.update_batch()
1035
1121
 
1036
1122
  @traced
@@ -1072,7 +1158,7 @@ class ContinuousMixin:
1072
1158
  max_queue_size: int = 0,
1073
1159
  num_q_cuda_graphs: int = 0,
1074
1160
  num_kv_cuda_graphs: int = 0,
1075
- allow_prefix_sharing: bool = True,
1161
+ allow_block_sharing: bool = True,
1076
1162
  block: bool = True,
1077
1163
  timeout: float | None = None,
1078
1164
  ) -> Generator[ContinuousBatchingManager]:
@@ -1082,7 +1168,7 @@ class ContinuousMixin:
1082
1168
  max_queue_size,
1083
1169
  num_q_cuda_graphs,
1084
1170
  num_kv_cuda_graphs,
1085
- allow_prefix_sharing,
1171
+ allow_block_sharing,
1086
1172
  )
1087
1173
  manager.start()
1088
1174
  try:
@@ -1099,18 +1185,19 @@ class ContinuousMixin:
1099
1185
  generation_config: GenerationConfig | None = None,
1100
1186
  manual_eviction: bool = False,
1101
1187
  max_queue_size: int = 0,
1102
- num_q_cuda_graphs: int = 0,
1103
- num_kv_cuda_graphs: int = 0,
1104
- allow_prefix_sharing: bool = True,
1188
+ num_q_padding_intervals: int = 0,
1189
+ num_kv_padding_intervals: int = 0,
1190
+ allow_block_sharing: bool = True,
1105
1191
  ) -> ContinuousBatchingManager:
1106
1192
  """Initialize a manager for continuous batching inference.
1107
1193
 
1108
1194
  Args:
1109
- generation_config: Custom generation configuration
1195
+ generation_config: An optional generation configuration, which may contain a CompileConfig object
1110
1196
  manual_eviction: Whether to manually evict requests from the cache
1111
1197
  max_queue_size: Maximum size of the input request queue
1112
- num_q_cuda_graphs: Number of CUDA graphs to use for the query dimension
1113
- num_kv_cuda_graphs: Number of CUDA graphs to use for the keys/values dimension
1198
+ num_q_padding_intervals: Number of intervals used to pad the query dimension
1199
+ num_kv_padding_intervals: Number of intervals used to pad the keys/values dimension
1200
+ allow_block_sharing: A flag to allow block sharing if the model has some full attention layers
1114
1201
 
1115
1202
  Returns:
1116
1203
  `ContinuousBatchingManager`: The manager instance to add requests and retrieve results.
@@ -1132,9 +1219,9 @@ class ContinuousMixin:
1132
1219
  generation_config=gen_config,
1133
1220
  manual_eviction=manual_eviction,
1134
1221
  max_queue_size=max_queue_size,
1135
- num_q_cuda_graphs=num_q_cuda_graphs,
1136
- num_kv_cuda_graphs=num_kv_cuda_graphs,
1137
- allow_prefix_sharing=allow_prefix_sharing,
1222
+ num_q_padding_intervals=num_q_padding_intervals,
1223
+ num_kv_padding_intervals=num_kv_padding_intervals,
1224
+ allow_block_sharing=allow_block_sharing,
1138
1225
  )
1139
1226
 
1140
1227
  # TODO: support streaming
@@ -1144,11 +1231,11 @@ class ContinuousMixin:
1144
1231
  self,
1145
1232
  inputs: list[list[int]],
1146
1233
  generation_config: GenerationConfig | None = None,
1147
- progress_bar: bool = True,
1148
- num_q_cuda_graphs: int = 0,
1149
- num_kv_cuda_graphs: int = 0,
1150
- allow_prefix_sharing: bool = True,
1234
+ num_q_padding_intervals: int = 0,
1235
+ num_kv_padding_intervals: int = 0,
1236
+ allow_block_sharing: bool = True,
1151
1237
  record_timestamps: bool = False,
1238
+ progress_bar: bool = True,
1152
1239
  **kwargs,
1153
1240
  ) -> dict[str, GenerationOutput]:
1154
1241
  """Generate sequences for a batch of prompts using continuous batching.
@@ -1156,14 +1243,15 @@ class ContinuousMixin:
1156
1243
  Args:
1157
1244
  inputs: List of input token sequences (prompts)
1158
1245
  generation_config: Optional generation configuration
1159
- num_q_cuda_graphs: Number of CUDA graphs to use for the query dimension
1160
- num_kv_cuda_graphs: Number of CUDA graphs to use for the keys/values dimension
1246
+ num_q_padding_intervals: Number of intervals used to pad the query dimension
1247
+ num_kv_padding_intervals: Number of intervals used to pad the keys/values dimension
1248
+ allow_block_sharing: A flag to allow block sharing if the model has some full attention layers
1249
+ record_timestamps: If set to True, each request will record a timestamp for every generated token
1250
+ progress_bar: If set to True, a progress bar will be displayed
1161
1251
  **kwargs: Additional generation parameters
1162
1252
 
1163
1253
  Returns:
1164
- `list[list[int]]`: A list containing the generated sequences (including prompt tokens
1165
- if not handled otherwise) for each input prompt, in the same order.
1166
- Returns an empty list `[]` for requests that failed.
1254
+ `dict[str, GenerationOutput]`: a dictionary mapping request ids to GenerationOutput objects
1167
1255
  """
1168
1256
  if not inputs:
1169
1257
  return {}
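
A hedged end-to-end usage sketch based on the signatures visible in this diff (the checkpoint name is a placeholder and the GenerationOutput attribute name is assumed, not confirmed here): the CompileConfig travels inside the GenerationConfig, and generate_batch returns a dict keyed by request id.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.configuration_utils import CompileConfig, GenerationConfig

checkpoint = "some-org/some-causal-lm"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype=torch.bfloat16, device_map="auto")

prompts = ["Hello there", "Continuous batching in one line:"]
inputs = [tokenizer(p).input_ids for p in prompts]

generation_config = GenerationConfig(
    do_sample=False,
    compile_config=CompileConfig(mode="max-autotune-no-cudagraphs"),  # mode suggested by the warning above
)
outputs = model.generate_batch(inputs=inputs, generation_config=generation_config, max_new_tokens=32)
for request_id, output in outputs.items():
    print(request_id, tokenizer.decode(output.generated_tokens))  # attribute name assumed from GenerationOutput
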
@@ -1173,26 +1261,30 @@ class ContinuousMixin:
1173
1261
 
1174
1262
  # Initialize manager with the batch inputs
1175
1263
  results = {}
1176
- num_requests = len(inputs)
1177
- with (
1178
- self.continuous_batching_context_manager(
1179
- generation_config=generation_config,
1180
- num_q_cuda_graphs=num_q_cuda_graphs,
1181
- num_kv_cuda_graphs=num_kv_cuda_graphs,
1182
- allow_prefix_sharing=allow_prefix_sharing,
1183
- block=True,
1184
- timeout=5,
1185
- ) as manager,
1186
- logging_redirect_tqdm([logger]),
1187
- tqdm(
1188
- total=num_requests,
1189
- disable=(not progress_bar),
1190
- desc=f"Solving {num_requests} requests",
1191
- unit="request",
1192
- ) as pbar,
1193
- ):
1264
+ gen_cfg = self.generation_config if generation_config is None else generation_config
1265
+ num_requests = len(inputs) * gen_cfg.num_return_sequences
1266
+ # Prepare context managers for the main loop
1267
+ manager_cm = self.continuous_batching_context_manager(
1268
+ generation_config=generation_config,
1269
+ num_q_cuda_graphs=num_q_padding_intervals,
1270
+ num_kv_cuda_graphs=num_kv_padding_intervals,
1271
+ allow_block_sharing=allow_block_sharing,
1272
+ block=True,
1273
+ timeout=5,
1274
+ )
1275
+ logging_cm = logging_redirect_tqdm([logger])
1276
+ pbar_cm = tqdm(
1277
+ total=num_requests,
1278
+ disable=(not progress_bar),
1279
+ desc=f"Solving {num_requests} requests",
1280
+ unit="request",
1281
+ )
1282
+ # Main loop
1283
+ with manager_cm as manager, logging_cm, pbar_cm as pbar:
1194
1284
  try:
1195
- manager.add_requests(inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens"))
1285
+ manager.add_requests(
1286
+ inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens"), record_timestamps=record_timestamps
1287
+ )
1196
1288
  finished_count = 0
1197
1289
  while finished_count < num_requests:
1198
1290
  result = manager.get_result(timeout=1)