transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -121,7 +121,7 @@ class PagedAttentionCache:
121
121
  device: torch.device,
122
122
  dtype: torch.dtype = torch.float16,
123
123
  tp_size: int | None = None,
124
- allow_prefix_sharing: bool = True,
124
+ allow_block_sharing: bool = True,
125
125
  ) -> None:
126
126
  """Initialize a paged attention cache for efficient memory usage. Also turns in prefix sharing if the model has
127
127
  only full attention layers.
@@ -132,7 +132,8 @@ class PagedAttentionCache:
132
132
  device: Device for the cache tensors
133
133
  dtype: Data type of the cache
134
134
  tp_size: Tensor parallelism size
135
- allow_prefix_sharing: A flag to allow prefix sharing if the model has only full attention layers.
135
+ allow_block_sharing: A flag to allow block sharing. If the model has some full attention layers, then prefix
136
+ sharing is enabled as well.
136
137
  """
137
138
  self.config = config
138
139
  self.dtype = dtype
@@ -209,7 +210,7 @@ class PagedAttentionCache:
209
210
  self.key_cache: list[torch.Tensor] = []
210
211
  self.value_cache: list[torch.Tensor] = []
211
212
  # We add two extra tokens to the cache to handle padding and generally discard unwanted tokens
212
- self.cache_shape = (num_blocks * self.block_size + 2, self.num_key_value_heads, self.head_dim)
213
+ self.cache_shape = ((num_blocks + 2) * self.block_size, self.num_key_value_heads, self.head_dim)
213
214
  for _ in range(group_size):
214
215
  new_layer_key_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
215
216
  new_layer_value_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
@@ -220,19 +221,20 @@ class PagedAttentionCache:
220
221
  logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")
221
222
 
222
223
  # Block management data structures
224
+ self.allow_block_sharing = allow_block_sharing
223
225
  self.group_cache_managers: list[CacheAllocator] = []
224
226
  for i, group_type in enumerate(group_types):
225
227
  if group_type == "full_attention":
226
- cm = FullAttentionCacheAllocator(i, self.block_size)
228
+ cm = FullAttentionCacheAllocator(i, self.block_size, allow_block_sharing=allow_block_sharing)
227
229
  elif group_type == "sliding_attention":
228
230
  cm = SlidingAttentionCacheAllocator(i, self.block_size, config.sliding_window)
229
231
  else:
230
232
  raise ValueError(f"Invalid group type: {group_type}")
231
233
  self.group_cache_managers.append(cm)
232
234
 
233
- # We only use prefix sharing if the whole model has only full attention layers
234
- self.use_prefix_sharing = allow_prefix_sharing and group_types == ["full_attention"]
235
- self._block_manager = BlockManager(num_blocks, self.block_size, self.use_prefix_sharing)
235
+ # We only use prefix sharing if the whole model has only full attention layers and block sharing is allowed
236
+ self.use_prefix_sharing = allow_block_sharing and group_types == ["full_attention"]
237
+ self._block_manager = BlockManager(num_blocks, self.block_size)
236
238
  self.blocks_to_complete: dict[str, int] = {}
237
239
  self._total_prefix_length: int = 0 # a counter to measure the impact of prefix sharing, also used in tests
238
240
 
@@ -352,7 +354,8 @@ class PagedAttentionCache:
352
354
  allocated_blocks = []
353
355
  for b in range(len(prompt_ids) // self.block_size):
354
356
  tokens = prompt_ids[b * self.block_size : (b + 1) * self.block_size]
355
- current_hash = self._block_manager.compute_hash(current_hash, tokens)
357
+ # Prefix sharing is only supported when there is only one full attention layer group, so group_id=0.
358
+ current_hash = self._block_manager.compute_hash(current_hash, tokens, group_id=0)
356
359
  block_id = self._block_manager._hash_to_id.get(current_hash)
357
360
  if block_id is not None:
358
361
  allocated_blocks.append(block_id)
@@ -369,18 +372,44 @@ class PagedAttentionCache:
369
372
  self._total_prefix_length += prefix_length
370
373
  return prefix_length
371
374
 
372
- def mark_blocks_as_complete(self, state: RequestState) -> None:
373
- """Marks the blocks that have been computed in the forward pass as complete. If prefix sharing is off, this is
374
- a no-op."""
375
- num_complete_blocks = 0 if not self.use_prefix_sharing else self.blocks_to_complete.pop(state.request_id)
375
+ def mark_shareable_blocks_as_complete(self, state: RequestState) -> None:
376
+ """Marks the blocks allocated to a request (state) as complete if they are shareable and they have been computed
377
+ in the forward pass. A complete block is a block where the KV cache has been fully computed: if the block has
378
+ enough space to hold the cache for N tokens, the block is marked as complete when the cache data is present for
379
+ the N tokens. If block sharing is off, this is a no-op."""
380
+ num_complete_blocks = 0 if not self.allow_block_sharing else self.blocks_to_complete.pop(state.request_id)
376
381
  if num_complete_blocks == 0:
377
382
  return None
378
- cm = self.group_cache_managers[0] # if prefix sharing is on, there is only one group
379
- self._block_manager.mark_blocks_as_complete(
380
- num_complete_blocks=num_complete_blocks,
381
- allocated_blocks=cm.block_table[state.request_id],
382
- prompt_ids=(state.initial_tokens + state.generated_tokens),
383
- )
383
+ for cm in self.group_cache_managers:
384
+ if cm.uses_block_sharing:
385
+ self._block_manager.mark_shareable_blocks_as_complete(
386
+ num_complete_blocks=num_complete_blocks,
387
+ allocated_blocks=cm.block_table[state.request_id],
388
+ prompt_ids=(state.initial_tokens + state.generated_tokens),
389
+ )
390
+
391
+ def copy_cache(self, source_blocks: list[int], forked_blocks: list[int]) -> None:
392
+ """Copy the cache from the source blocks to the forked blocks."""
393
+ source_blocks = torch.tensor(source_blocks, device=self.device, dtype=torch.int32)
394
+ forked_blocks = torch.tensor(forked_blocks, device=self.device, dtype=torch.int32)
395
+ for key_cache, value_cache in zip(self.key_cache, self.value_cache):
396
+ key_cache = key_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim)
397
+ value_cache = value_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim)
398
+ key_cache[forked_blocks] = key_cache[source_blocks]
399
+ value_cache[forked_blocks] = value_cache[source_blocks]
400
+ # FIXME: consolidate the cache into a single tensor of shape (group_size, 2, *self.k_or_v_cache_shape)
401
+ # This will allow for better .update and a single copy instead of one per cache tensor
402
+
403
+ def fork_request(self, source_request_id: str, destination_request_ids: list[str]) -> tuple[list[int], list[int]]:
404
+ """Fork the cache of a request (state) into the one of a list of requests with the given (dst_request_ids)."""
405
+ # These lists will be the accumulators for the source and destination blocks for the cache copy
406
+ source_blocks, destination_blocks = [], []
407
+ # Main fork loop
408
+ for cm in self.group_cache_managers:
409
+ src_blocks, dst_blocks = cm.fork_blocks(source_request_id, destination_request_ids, self._block_manager)
410
+ source_blocks.extend(src_blocks)
411
+ destination_blocks.extend(dst_blocks)
412
+ return source_blocks, destination_blocks
384
413
 
385
414
 
386
415
  # TODO: rework computation with the groups and their sizes
@@ -31,20 +31,21 @@ def reverse_enumerate(xs: list[T]) -> Iterator[tuple[int, T]]:
31
31
  index -= 1
32
32
 
33
33
 
34
- class Block:
34
+ class Block: # TODO: rename to ShareableBlock and update the docs
35
35
  """A class to represent a block managed by the block manager. We say that a block is complete when the physical KV
36
36
  cache it points to is fully computed. A block can have a parent, which is the block that came before in the
37
- sequence. Once a block is complete, it is given a hash, which takes into account the tokens ids of the block and
38
- its parent's hash (if there is a parent)."""
37
+ sequence. Once a block is complete, it is given a hash, which takes into account the tokens ids of the block, the
38
+ layer (group_id) it belong to and its parent's hash (if there is a parent)."""
39
39
 
40
- def __init__(self, id_: int, parent_id: int | None) -> None:
40
+ def __init__(self, id_: int, parent_id: int | None, group_id: int) -> None:
41
41
  self.id: int = id_
42
42
  self.parent_id: int | None = parent_id
43
+ self.group_id: int = group_id
43
44
  self.hash: int | None = None
44
45
  self.ref_count: int = 1
45
46
 
46
47
  def __repr__(self) -> str:
47
- return f"Block(id={self.id}, parent_id={self.parent_id}, hash={self.hash}, ref_count={self.ref_count})"
48
+ return f"Block(id={self.id}, parent_id={self.parent_id}, group_id={self.group_id}, hash={self.hash}, ref_count={self.ref_count})"
48
49
 
49
50
  @property
50
51
  def is_complete(self) -> bool:
@@ -52,8 +53,9 @@ class Block:
52
53
 
53
54
 
54
55
  class BlockManager:
55
- """A class to manage the number of free blocks and block re-use. If prefix sharing is off, the block manager is a
56
- simple FIFO structure where blocks are either free or in use. If prefix sharing is on, blocks can have 3 states:
56
+ """A class to manage the number of free blocks and block re-use. When a block becomes in use, a flag is passed to
57
+ determine if the block is shareable or not. If it is, then a Block object is created and kept track of internally.
58
+ It can have the following states:
57
59
  - in use: one or more requests references this block, thus it cannot be written over. The number of requests
58
60
  referencing this block is stored as ref_count in the Block object.
59
61
  - un-initialized: the block points to a space in the KV cache tensor that contains no data yet. Those blocks can
@@ -63,19 +65,19 @@ class BlockManager:
63
65
  the ref_count of the block and remove it from the list of initialized blocks, because it is now in use.
64
66
  Still, the block can be freed if no un-initialized blocks are left. In that case, we remove its hash from the
65
67
  hash table.
68
+ If the block is not shareable, we just use the block manager as a FIFO structure where blocks are either free or in
69
+ use. Sharability is determined by the type of cache allocator: blocks created for full attention layers are
70
+ shareable, while blocks created for sliding window attention layers are not.
66
71
  There is no structure to keep track of the blocks in use: if a block is neither un-initialized nor initialized,
67
72
  it is in use.
68
73
  """
69
74
 
70
- def __init__(self, num_blocks: int, block_size: int, use_prefix_sharing: bool) -> None:
71
- """Initializes the block manager with a given number of blocks (num_blocks) of size (block_size). Prefix sharing
72
- can be turned on with the (use_prefix_sharing) flag, which only happens if the model has only full attention
73
- layers."""
75
+ def __init__(self, num_blocks: int, block_size: int) -> None:
76
+ """Initializes the block manager with a given number of blocks (num_blocks) of size (block_size)."""
74
77
  self.num_blocks = num_blocks
75
78
  self.block_size = block_size
76
79
  self._uninit_block_ids = deque(range(num_blocks))
77
80
  self._init_block_ids: dict[int, None] = {} # effectively act as an ordered set
78
- self._use_prefix_sharing = use_prefix_sharing
79
81
  self._hash_to_id: dict[int, int] = {}
80
82
  self._id_to_block: dict[int, Block] = {}
81
83
 
@@ -102,22 +104,81 @@ class BlockManager:
102
104
  self._uninit_block_ids.append(id_to_uninitialize)
103
105
  return True
104
106
 
105
- def get_free_blocks(self, n_blocks: int, last_block_id: int | None) -> list[int] | None:
106
- """Returns a list of (n_blocks) free block and mark them as no longuer free in the internal data structures. One
107
- can also pass a (last_block_id) to indicate the last block id in the sequence, which is used to keep track of
108
- the parent block. If the manager cannot find enough free blocks, it returns None."""
107
+ def get_free_blocks(
108
+ self, n_blocks: int, last_block_id: int | None, shareable: bool, group_id: int
109
+ ) -> list[int] | None:
110
+ """Returns a list of (n_blocks) free block and mark them as no longuer free in the internal data structures.
111
+ If the (shareable) flag is set to True, a Block object is created to keep track of the block, with the
112
+ (last_block_id) to indicate the last block id in the sequence, also named the parent block. If the manager
113
+ cannot find enough free blocks, it returns None."""
109
114
  if not self.has_enough_free_blocks(n_blocks):
110
115
  return None
111
116
  allocated_block_ids = [self._uninit_block_ids.popleft() for _ in range(n_blocks)]
112
- # If we use prefix caching, we keep track of the allocated blocks as partial blocks
113
- if self._use_prefix_sharing:
117
+ # If the block is shareable, we keep track of the allocated blocks as partial blocks
118
+ if shareable:
114
119
  for block_id in allocated_block_ids:
115
- block = Block(block_id, last_block_id)
120
+ block = Block(block_id, last_block_id, group_id)
116
121
  self._id_to_block[block_id] = block
117
122
  last_block_id = block_id
118
123
  # In both cases, we return the allocated block ids
119
124
  return allocated_block_ids
120
125
 
126
+ def fork_blocks(
127
+ self, parent_blocks: list[int], num_forks: int, shareable: bool, group_id: int
128
+ ) -> tuple[list[list[int]], list[int], list[int]]:
129
+ """Fork a given list of (parent_blocks) as many times as (num_forks). If the blocks are (shareable), we use
130
+ reference on the blocks that are complete. Otherwise, we allocate new blocks and keep track of their indices to
131
+ later copy the physical cache. For instance, when forking 4 blocks for 2 children:
132
+
133
+ Parent blocks: [0, 1, 2, 3], with all blocks being complete except the last one (block 3).
134
+
135
+ ----------------------------------------- IF BLOCKS ARE NOT SHAREABLE -----------------------------------------
136
+
137
+ Forked blocks lists: [[5, 6, 7, 8], [9, 10, 11, 12]]
138
+ Copy source: [0, 1, 2, 3, 0, 1, 2, 3]
139
+ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
140
+ Copy destination: [5, 6, 7, 8, 9, 10, 11, 12] → 8 blocks are newly allocated and copied
141
+
142
+ ----------------------------------------- IF BLOCKS ARE SHAREABLE ---------------------------------------------
143
+
144
+ Forked blocks lists: [[0, 1, 2, 5], [0, 1, 2, 6]]
145
+ Copy source: [ 3, 3] (block 3 is not complete so it's copied, not referenced)
146
+ ↓ ↓
147
+ Copy destination: [ 5, 6] → only 2 blocks are newly allocated and copied
148
+ """
149
+ # First phase: reference all complete blocks
150
+ forked_by_reference = []
151
+
152
+ if shareable:
153
+ for block_id in parent_blocks:
154
+ block = self._id_to_block[block_id]
155
+ if block.is_complete:
156
+ forked_by_reference.append(block.id)
157
+ block.ref_count += num_forks
158
+ else:
159
+ break
160
+
161
+ # Early return if we have forked all blocks by reference
162
+ blocks_to_copy = len(parent_blocks) - len(forked_by_reference)
163
+ if blocks_to_copy == 0:
164
+ return [forked_by_reference[:] for _ in range(num_forks)], [], []
165
+
166
+ # From now on, each child will have its own list of blocks
167
+ forked_blocks_lists = []
168
+ copy_src = []
169
+ copy_dst = []
170
+
171
+ # Second phase: allocate new blocks if needed
172
+ parent_id = forked_by_reference[-1] if forked_by_reference else None
173
+ for _ in range(num_forks):
174
+ allocated_block_ids = self.get_free_blocks(blocks_to_copy, parent_id, shareable, group_id)
175
+ if allocated_block_ids is None:
176
+ return None, [], []
177
+ forked_blocks_lists.append(forked_by_reference + allocated_block_ids)
178
+ copy_src.extend(parent_blocks[-blocks_to_copy:])
179
+ copy_dst.extend(allocated_block_ids)
180
+ return forked_blocks_lists, copy_src, copy_dst
181
+
121
182
  def increase_ref_count(self, block_id: int) -> None:
122
183
  """Increases the reference count of a given (block_id)."""
123
184
  block = self._id_to_block[block_id]
@@ -137,23 +198,23 @@ class BlockManager:
137
198
  self._id_to_block.pop(block_id)
138
199
  self._uninit_block_ids.append(block_id)
139
200
 
140
- def free_blocks(self, blocks: list[int]) -> None:
141
- """Marks a list of (blocks) as free. If there is no prefix sharing, we simply add them to the uninitialized
201
+ def free_blocks(self, blocks: list[int], shareable: bool) -> None:
202
+ """Marks a list of (blocks) as free. If the blocks were not (shareable), we simply add them to the uninitialized
142
203
  blocks queue. Otherwise, their new state depends on whether they are complete."""
143
- if self._use_prefix_sharing:
204
+ if shareable:
144
205
  for block_id in blocks:
145
206
  self.decrease_ref_count(block_id)
146
207
  else:
147
208
  self._uninit_block_ids.extend(blocks)
148
209
 
149
- def mark_blocks_as_complete(
210
+ def mark_shareable_blocks_as_complete(
150
211
  self, num_complete_blocks: int, allocated_blocks: list[int], prompt_ids: list[int]
151
212
  ) -> None:
152
213
  """Among the list of (allocated_blocks), mark (num_complete_blocks) incomplete blocks as now complete. The list
153
214
  of (prompt_ids) is used to compute the hash of the new block."""
154
215
  # Look for the first complete block, starting from the last block in the sequence
155
216
  parent_hash = None
156
- incomplete_blocks: list[Block] = []
217
+ incomplete_blocks: list[tuple[int, Block]] = []
157
218
  for i, block_id in reverse_enumerate(allocated_blocks):
158
219
  block = self._id_to_block[block_id]
159
220
  if block.is_complete:
@@ -178,7 +239,7 @@ class BlockManager:
178
239
  # Otherwise, we compute the hash
179
240
  num_complete_blocks -= 1
180
241
  tokens = prompt_ids[i * self.block_size : (i + 1) * self.block_size]
181
- block.hash = self.compute_hash(parent_hash, tokens)
242
+ block.hash = self.compute_hash(parent_hash, tokens, block.group_id)
182
243
 
183
244
  existing_block_id = self._hash_to_id.get(block.hash)
184
245
  # If the block hash is already in the hash to id mapping, we reference the existing block instead
@@ -187,19 +248,20 @@ class BlockManager:
187
248
  allocated_blocks[i] = existing_block_id
188
249
  self._id_to_block[existing_block_id].ref_count += 1
189
250
  new_parent_id = existing_block_id
190
- self.free_blocks([block.id])
251
+ self.free_blocks([block.id], shareable=True)
191
252
 
192
253
  # Otherwise, we add the completed block to the hash table
193
254
  else:
255
+ logger.debug(f"Adding new block {block.id} (group {block.group_id}) with hash {block.hash}")
194
256
  self._hash_to_id[block.hash] = block.id
195
257
 
196
258
  # Update loop variables
197
259
  parent_hash = block.hash
198
260
 
199
- def compute_hash(self, parent_hash: int | None, tokens: list[int]) -> int:
200
- """Computes the hash of a block containing the given (tokens) with a given (parent_hash). If the block has no
201
- parent, the parent hash is None."""
202
- return hash((parent_hash, tuple(tokens)))
261
+ def compute_hash(self, parent_hash: int | None, tokens: list[int], group_id: int) -> int:
262
+ """Computes the hash of a block identified by the (tokens) it contains, its (parent_hash) and the layer
263
+ (group_id) it belong to. If the block has no parent, the parent hash is None."""
264
+ return hash((parent_hash, tuple(tokens), group_id))
203
265
 
204
266
 
205
267
  class CacheAllocator(ABC):
@@ -208,6 +270,7 @@ class CacheAllocator(ABC):
208
270
 
209
271
  _index: int
210
272
  block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
273
+ uses_block_sharing: bool # flag to determine if the blocks are shareable
211
274
 
212
275
  @abstractmethod
213
276
  def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> int | None:
@@ -218,7 +281,7 @@ class CacheAllocator(ABC):
218
281
  """Frees all blocks associated with a (request_id) using the (block_manager)."""
219
282
  if request_id in self.block_table:
220
283
  blocks_to_free = self.block_table.pop(request_id)
221
- block_manager.free_blocks(blocks_to_free)
284
+ block_manager.free_blocks(blocks_to_free, shareable=self.uses_block_sharing)
222
285
  else:
223
286
  logger.warning(
224
287
  f"CacheAllocator {self._index} attempted to free blocks for non-existent request_id: {request_id}"
@@ -236,17 +299,48 @@ class CacheAllocator(ABC):
236
299
  def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
237
300
  """Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
238
301
 
302
+ def fork_blocks(
303
+ self, parent_request_id: str, children_request_ids: list[str], block_manager: BlockManager
304
+ ) -> tuple[list[int], list[int]]:
305
+ """Forks the cache blocks of a (parent_request_id) to a list of (children_request_ids). To manage the blocks,
306
+ the (block_manager) is used. When forking, the child's block are either shared with the parent, or they need to
307
+ be copied from the parent. Hence we return two lists of blocks that need to be copied: one for the source and
308
+ one for the destination."""
309
+
310
+ # Sanity checks
311
+ if parent_request_id not in self.block_table:
312
+ raise ValueError(f"No block table found for request {parent_request_id}")
313
+
314
+ # Actual forking
315
+ parent_blocks = self.block_table[parent_request_id]
316
+ list_forked_blocks, copy_src, copy_dst = block_manager.fork_blocks(
317
+ parent_blocks=parent_blocks,
318
+ num_forks=len(children_request_ids),
319
+ shareable=self.uses_block_sharing,
320
+ group_id=self._index,
321
+ )
322
+ if list_forked_blocks is None:
323
+ raise ValueError(f"Failed to fork blocks for request {parent_request_id}")
324
+
325
+ # Update the block table for all children requests
326
+ for children_request_id, forked_blocks in zip(children_request_ids, list_forked_blocks):
327
+ if children_request_id in self.block_table:
328
+ raise ValueError(f"Block table already exists for request {children_request_id}")
329
+ self.block_table[children_request_id] = forked_blocks
330
+ return copy_src, copy_dst
331
+
239
332
 
240
333
  class FullAttentionCacheAllocator(CacheAllocator):
241
334
  """Cache manager for a group of full attention layers."""
242
335
 
243
- def __init__(self, index: int, block_size: int) -> None:
336
+ def __init__(self, index: int, block_size: int, allow_block_sharing: bool) -> None:
244
337
  """Initializes the cache manager for a group of full attention layers.
245
338
  Args:
246
339
  - index: the index of the associated layer group
247
340
  - block_size: the size of the blocks in the cache
248
341
  """
249
342
  self._index = index
343
+ self.uses_block_sharing = allow_block_sharing
250
344
  self.block_size = block_size
251
345
  self.block_table = {}
252
346
 
@@ -261,7 +355,7 @@ class FullAttentionCacheAllocator(CacheAllocator):
261
355
  else:
262
356
  last_block_id = self.block_table[request_id][-1]
263
357
  # Actual allocation, return early if failed
264
- allocated_blocks = block_manager.get_free_blocks(n_blocks, last_block_id)
358
+ allocated_blocks = block_manager.get_free_blocks(n_blocks, last_block_id, self.uses_block_sharing, self._index)
265
359
  if allocated_blocks is None:
266
360
  return None
267
361
  self.block_table[request_id].extend(allocated_blocks)
@@ -315,6 +409,7 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
315
409
  - sliding_window: the size of the sliding window
316
410
  """
317
411
  self._index = index
412
+ self.uses_block_sharing = False
318
413
  self.block_size = block_size
319
414
  self.sliding_window = sliding_window
320
415
  self._max_blocks_per_request = ceil(self.sliding_window / self.block_size)
@@ -334,7 +429,9 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
334
429
  after_allocation = min(already_allocated + n_blocks, self._max_blocks_per_request)
335
430
  actual_n_blocks = after_allocation - already_allocated
336
431
  # Classic allocation
337
- allocated_blocks = block_manager.get_free_blocks(actual_n_blocks, None) # no prefix caching w/ sliding window
432
+ allocated_blocks = block_manager.get_free_blocks(
433
+ actual_n_blocks, None, self.uses_block_sharing, self._index
434
+ ) # no block sharing w/ sliding window
338
435
  if allocated_blocks is None:
339
436
  return None
340
437
  self.block_table[request_id].extend(allocated_blocks)