transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/generation/continuous_batching/continuous_api.py

@@ -259,7 +259,7 @@ class ContinuousBatchProcessor:
         self.cumulative_seqlens_q = torch.empty((self.max_batch_tokens + 1,), **self.tensor_metadata)
         self.max_seqlen_q = 0
         self.logits_indices = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)
-        self.output_ids = torch.empty((1, self.max_batch_tokens), **self.tensor_metadata)
+        self.output_ids = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)

         # For some kwargs, we have a dict of tensors with as many items as there are attention types
         layer_types = getattr(self.config, "layer_types", None)
@@ -311,7 +311,7 @@ class ContinuousBatchProcessor:
         self.cumulative_seqlens_q[: b_size + 1].zero_()
         self.max_seqlen_q = 0
         self.logits_indices[:q_len].fill_(-1)
-        self.output_ids[:, :q_len].fill_(-1)
+        self.output_ids[:q_len].fill_(-1)

         # Reset the attributes that are either tensors or dict of tensors
         for layer_type in self.cumulative_seqlens_k:
@@ -447,7 +447,7 @@ class ContinuousBatchProcessor:
         self.metrics.record_batch_metrics(self.requests_in_batch)

         # Reset the static tensors used for storage
-        self.reset_static_tensors()  # TODO: this might be unnecessary
+        self.reset_static_tensors()  # FIXME: why does this make the generation faster?

         # Prepare accumulators
         self.actual_query_length = 0
@@ -557,13 +557,10 @@ class ContinuousBatchProcessor:
         self.actual_index_sizes[i] = (len(group_read_indices), len(group_write_indices))

     @traced
-    def _sync(self) -> list[int]:
-        if self.output_ids is not None:
-            try:
-                return self.output_ids.tolist()[0]
-            except Exception:
-                return [0, 1]
-        return [0, 0]
+    def _get_new_tokens(self, num_new_tokens: int) -> list[int]:
+        indices = self.logits_indices[:num_new_tokens]
+        new_tokens = self.output_ids[indices]
+        return new_tokens.tolist()

     @traced
     def _maybe_send_output(self, state: RequestState) -> None:
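The new `_get_new_tokens` replaces the old catch-all `_sync` with an explicit gather: the first `num_new_tokens` entries of `logits_indices` point at the positions in the flat `output_ids` buffer holding each request's freshly sampled token. A minimal, self-contained sketch of that indexing pattern (positions and token ids invented):

```python
import torch

# Toy setup: three requests whose sampled tokens landed at positions 4, 9 and 11
# of the flattened per-step buffer.
max_batch_tokens = 16
output_ids = torch.full((max_batch_tokens,), -1, dtype=torch.long)
output_ids[torch.tensor([4, 9, 11])] = torch.tensor([101, 202, 303])
logits_indices = torch.full((max_batch_tokens,), -1, dtype=torch.long)
logits_indices[:3] = torch.tensor([4, 9, 11])

def get_new_tokens(num_new_tokens: int) -> list[int]:
    # Same gather as `_get_new_tokens`: one token per in-flight request,
    # read back to the host in a single `.tolist()` call.
    indices = logits_indices[:num_new_tokens]
    return output_ids[indices].tolist()

assert get_new_tokens(3) == [101, 202, 303]
```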
@@ -574,29 +571,56 @@ class ContinuousBatchProcessor:
     @traced
     def update_batch(self) -> None:
         """Update request states based on generated tokens."""
-        out_tokens = self._sync()
-        for i, state in enumerate(self.requests_in_batch):
+        new_tokens = self._get_new_tokens(len(self.requests_in_batch))
+        current_logits_index = 0
+        for state in self.requests_in_batch:
             # If the request has no remaining prompt ids, it means prefill has already ended or just finished
             if len(state.remaining_prefill_tokens) == 0:
-                self.metrics.record_ttft_metric(state.created_time, state.request_id)
-                state.status = RequestStatus.DECODING
-                token = out_tokens[self.logits_indices[i]]
+                # If there are no generated tokens yet, it means prefill just ended
+                if state.generated_len() == 0:
+                    self.metrics.record_ttft_metric(state.created_time, state.request_id)
+                    state.status = RequestStatus.DECODING
+
+                token = new_tokens[current_logits_index]
                 state.tokens_to_process = [token]
+                current_logits_index += 1
+

                 # Update the request and stop if it is complete
                 is_finished = state.update_and_check_completion(token)
                 # We mark the completed blocks as such
-                self.cache.mark_blocks_as_complete(state)
+                self.cache.mark_shareable_blocks_as_complete(state)
                 if is_finished:
                     self.metrics.record_request_completion(state.created_time, state.request_id)
                     self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
                     self._maybe_send_output(state)
             # Otherwise, the request is still prefilling, but the prefill has been split
             elif state.status == RequestStatus.PREFILLING_SPLIT:
-                self.cache.mark_blocks_as_complete(state)
+                self.cache.mark_shareable_blocks_as_complete(state)
                 state.status = RequestStatus.SPLIT_PENDING_REMAINDER
             else:
                 raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}")

+        # If some requests need to be forked, we do it now
+        copy_source, copy_destination = [], []
+        while self.scheduler._requests_to_fork:
+            # Get the number of children and reset it so it's not forked again
+            state = self.scheduler._requests_to_fork.pop()
+            num_children = state.num_children
+            state.num_children = 0
+            # Create the new request and add them to the scheduler
+            new_request_ids = [f"{state.request_id}__child#{i}" for i in range(num_children)]
+            for new_request_id in new_request_ids:
+                self.scheduler.active_requests[new_request_id] = state.fork(new_request_id)
+            # Fork the cache
+            copy_src, copy_dst = self.cache.fork_request(state.request_id, new_request_ids)
+            copy_source.extend(copy_src)
+            copy_destination.extend(copy_dst)
+            # FIXME: if fork cant be done, create a new pending request without forking instead of crashing everything
+
+        # The copy induced by the fork is done in one go (if it's even needed)
+        if copy_source:
+            self.cache.copy_cache(copy_source, copy_destination)
+
         if self.cache.get_num_free_blocks() == 0:
             raise ValueError("No more free blocks")

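The new block implements request forking for `num_return_sequences`: a parent request is cloned into children that reuse its cache blocks, and any block that must be duplicated is recorded in `copy_source`/`copy_destination` so a single batched `copy_cache` call performs all copies at once. A stubbed sketch of that bookkeeping, with invented block indices standing in for the real `PagedAttentionCache` state:

```python
from itertools import count

free_blocks = count(100)  # hypothetical allocator handing out free block indices
parent_last_block = {"req-0": 7, "req-1": 42}  # hypothetical partially-filled blocks

def fork_request(parent_id: str, child_ids: list[str]) -> tuple[list[int], list[int]]:
    # Stand-in for PagedAttentionCache.fork_request: each child gets a fresh
    # block into which the parent's partially filled block must be copied.
    dst = [next(free_blocks) for _ in child_ids]
    return [parent_last_block[parent_id]] * len(dst), dst

copy_source, copy_destination = [], []
for parent_id, num_children in [("req-0", 2), ("req-1", 1)]:
    child_ids = [f"{parent_id}__child#{i}" for i in range(num_children)]
    src, dst = fork_request(parent_id, child_ids)
    copy_source.extend(src)
    copy_destination.extend(dst)

# One batched copy instead of one per child keeps the number of copy
# operations independent of how many requests were forked.
print(copy_source, copy_destination)  # [7, 7, 42] [100, 101, 102]
```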
@@ -727,12 +751,11 @@ class ContinuousBatchProcessor:
             probs = nn.functional.softmax(probs, dim=-1)
             # probs[0] has shape [seq_len, vocab_size], multinomial returns [seq_len, 1]
             next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1)  # Now [seq_len]
-            # Add batch dimension back to match argmax output
-            next_tokens = next_tokens.unsqueeze(0)  # Now [1, seq_len]
         else:
-            next_tokens = torch.argmax(probs, dim=-1)  # Already [1, seq_len]
-        tokens = next_tokens.size(1)  # Get seq_len dimension
-        self.output_ids[:, :tokens].copy_(next_tokens)
+            next_tokens = torch.argmax(probs, dim=-1)  # shape is [1, seq_len]
+            next_tokens = next_tokens.squeeze(0)  # shape is [seq_len]
+        tokens = next_tokens.size(0)  # Get seq_len dimension
+        self.output_ids[:tokens].copy_(next_tokens)


 # Manager Class (User Interface)
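Both sampling paths now yield a flat `[seq_len]` tensor, which is what lets `output_ids` shrink to one dimension in the first hunk. A quick, self-contained shape check with toy logits (all names local to this snippet):

```python
import torch

probs = torch.softmax(torch.randn(1, 4, 5), dim=-1)  # [batch=1, seq_len=4, vocab=5]

sampled = torch.multinomial(probs[0], num_samples=1).squeeze(-1)  # [4, 1] -> [4]
greedy = torch.argmax(probs, dim=-1).squeeze(0)                   # [1, 4] -> [4]
assert sampled.shape == greedy.shape == (4,)

# Either result can now be written into the flat 1D buffer with copy_:
output_ids = torch.empty(16, dtype=torch.long)
output_ids[: greedy.size(0)].copy_(greedy)
```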
@@ -752,7 +775,7 @@ class ContinuousBatchingManager:
         max_queue_size: int = 0,
         num_q_padding_intervals: int = 0,
         num_kv_padding_intervals: int = 0,
-        allow_prefix_sharing: bool = True,
+        allow_block_sharing: bool = True,
     ) -> None:
         """Initialize the continuous batching manager.

@@ -762,30 +785,37 @@ class ContinuousBatchingManager:
             max_queue_size: Maximum size of the request queue (0 = unlimited)
             num_q_padding_intervals: (optional) Number of intervals used to pad the query dimension
             num_kv_padding_intervals: (optional) Number of intervals used to pad the keys/values dimension
-            allow_prefix_sharing: (optional) Whether to allow prefix sharing if the model has only full attention layers
+            allow_block_sharing: (optional) Whether to allow block sharing if the model has some full attention layers
         """
-        # Reloade paged version if necessary
+        # Reload paged version of the attention implementation if necessary
         if "paged|" not in model.config._attn_implementation:
             model.set_attn_implementation(f"paged|{model.config._attn_implementation}")

+        # Internal arguments
         self.model = model.eval()
-        generation_config = model.generation_config if generation_config is None else generation_config
-        self.generation_config = generation_config
+        self.manual_eviction = manual_eviction
+        self._allow_block_sharing = allow_block_sharing
+        self._use_prefix_sharing = allow_block_sharing  # approximation until the cache is created
+
         self.input_queue = queue.Queue(maxsize=max_queue_size)
         self.output_queue = queue.Queue()
         self.stop_event = threading.Event()
-        self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
+        self.batch_processor: ContinuousBatchProcessor | None = None
         self._generation_thread = None
         self._request_counter = 0
         self._request_lock = threading.Lock()
-        self.model.generation_config.top_p = None
+
+        # Generation config related arguments
+        generation_config = model.generation_config if generation_config is None else generation_config
+        self.generation_config = generation_config
+        self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
         self.do_sample = getattr(generation_config, "do_sample", True)
         self.logit_processor = self.model._get_logits_processor(generation_config)
-        self.profile = getattr(generation_config, "profile", False)  # TODO: not supported yet
-        self.manual_eviction = manual_eviction
-        self.batch_processor: ContinuousBatchProcessor | None = None
-        self._allow_prefix_sharing = allow_prefix_sharing
+        self.num_return_sequences = getattr(generation_config, "num_return_sequences", 1)
+
+        # self.model.generation_config.top_p = None NOTE: figure out why this was here

+        # Cuda graph behavior is determined below using either user-specified arguments or heuristics
         self.use_cuda_graph = self._decide_use_cuda_graphs(
             use_cuda_graph=getattr(generation_config, "use_cuda_graph", None),
             num_q_padding_intervals=num_q_padding_intervals,
@@ -799,6 +829,7 @@ class ContinuousBatchingManager:
             num_kv_padding_intervals if num_kv_padding_intervals > 0 else NUM_KV_PADDING_INTERVALS
         )

+        # Log probability generation is not supported yet (TODO)
         if self.log_prob_generation:
             raise NotImplementedError("log_prob_generation is not supported yet")

@@ -932,6 +963,7 @@ class ContinuousBatchingManager:
         state = RequestState(
             request_id=request_id,
             initial_tokens=list(input_ids),
+            num_children=self.num_return_sequences - 1,
             record_timestamps=record_timestamps,
             tokens_to_process=list(input_ids),
             max_new_tokens=max_new_tokens,
@@ -950,6 +982,10 @@
         streaming: bool = False,
         record_timestamps: bool = False,
     ) -> None:
+        # If there is prefix sharing, we sort the inputs to maximize cache hits
+        if self._use_prefix_sharing:
+            inputs = sorted(inputs, reverse=True)
+        # Add requests in order
         for input_ids in inputs:
             self.add_request(
                 input_ids, max_new_tokens=max_new_tokens, streaming=streaming, record_timestamps=record_timestamps
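Sorting the prompts lexicographically is a cheap way to make requests that share a prefix adjacent in submission order, which raises the chance that shared prompt blocks are still resident when the next request is scheduled. A toy illustration of the ordering (token ids invented):

```python
# Toy prompts as token-id lists; real inputs are full tokenized prompts.
prompts = [
    [1, 2, 3, 9],
    [5, 6],
    [1, 2, 3, 4],
    [1, 2, 7],
]
for p in sorted(prompts, reverse=True):
    print(p)
# [5, 6]
# [1, 2, 7]
# [1, 2, 3, 9]
# [1, 2, 3, 4]  <- prompts sharing the [1, 2, 3] prefix end up back to back
```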
@@ -1020,8 +1056,9 @@ class ContinuousBatchingManager:
             self.model.device,
             self.model.dtype,
             tp_size=getattr(self.model, "_tp_size", None),  # Use model's actual TP setting
-            allow_prefix_sharing=self._allow_prefix_sharing,
+            allow_block_sharing=self._allow_block_sharing,
         )
+        self._use_prefix_sharing = paged_attention_cache.use_prefix_sharing  # update the approximation
         logger.debug(f"PagedAttentionCache created in {perf_counter() - t0} seconds")

         scheduler = None
@@ -1080,10 +1117,6 @@ class ContinuousBatchingManager:
             )

             self._generation_step()
-
-            if torch.cuda.is_available():
-                torch.cuda.synchronize()  # FIXME: why is this needed?
-            # Processor updates the batch after generation step is truly over
             batch_processor.update_batch()

     @traced
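The explicit `torch.cuda.synchronize()` is presumably safe to drop because `update_batch` now goes through `_get_new_tokens`, whose final `.tolist()` is a blocking device-to-host copy that already waits for the pending sampling kernels. A minimal illustration of that implicit synchronization:

```python
import torch

if torch.cuda.is_available():
    x = torch.randn(1 << 20, device="cuda")
    y = (x * 2.0).sum()      # kernel is queued asynchronously on the CUDA stream
    host_value = y.tolist()  # blocking copy to host: it waits for the kernel,
                             # so no explicit torch.cuda.synchronize() is needed
    print(host_value)
```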
@@ -1125,7 +1158,7 @@ class ContinuousMixin:
         max_queue_size: int = 0,
         num_q_cuda_graphs: int = 0,
         num_kv_cuda_graphs: int = 0,
-        allow_prefix_sharing: bool = True,
+        allow_block_sharing: bool = True,
         block: bool = True,
         timeout: float | None = None,
     ) -> Generator[ContinuousBatchingManager]:
@@ -1135,7 +1168,7 @@ class ContinuousMixin:
             max_queue_size,
             num_q_cuda_graphs,
             num_kv_cuda_graphs,
-            allow_prefix_sharing,
+            allow_block_sharing,
         )
         manager.start()
         try:
@@ -1154,7 +1187,7 @@ class ContinuousMixin:
         max_queue_size: int = 0,
         num_q_padding_intervals: int = 0,
         num_kv_padding_intervals: int = 0,
-        allow_prefix_sharing: bool = True,
+        allow_block_sharing: bool = True,
     ) -> ContinuousBatchingManager:
         """Initialize a manager for continuous batching inference.

@@ -1164,7 +1197,7 @@ class ContinuousMixin:
             max_queue_size: Maximum size of the input request queue
             num_q_padding_intervals: Number of intervals used to pad the query dimension
             num_kv_padding_intervals: Number of intervals used to pad the keys/values dimension
-            allow_prefix_sharing: A flag to allow prefix sharing if the model has only full attention layers
+            allow_block_sharing: A flag to allow block sharing if the model has some full attention layers
 
         Returns:
             `ContinuousBatchingManager`: The manager instance to add requests and retrieve results.
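
A hedged usage sketch of this API (the token ids are placeholders, and `manager.stop()` is assumed from the surrounding library rather than shown in this diff; `start`, `add_request`, and `get_result` all appear in hunks here):

    manager = model.init_continuous_batching(allow_block_sharing=True)
    manager.start()
    try:
        manager.add_request([101, 2023, 2003, 102], max_new_tokens=32)
        result = manager.get_result(timeout=10)
    finally:
        manager.stop()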
@@ -1188,7 +1221,7 @@ class ContinuousMixin:
             max_queue_size=max_queue_size,
             num_q_padding_intervals=num_q_padding_intervals,
             num_kv_padding_intervals=num_kv_padding_intervals,
-            allow_prefix_sharing=allow_prefix_sharing,
+            allow_block_sharing=allow_block_sharing,
         )
 
     # TODO: support streaming
@@ -1200,7 +1233,7 @@ class ContinuousMixin:
         generation_config: GenerationConfig | None = None,
         num_q_padding_intervals: int = 0,
         num_kv_padding_intervals: int = 0,
-        allow_prefix_sharing: bool = True,
+        allow_block_sharing: bool = True,
         record_timestamps: bool = False,
         progress_bar: bool = True,
         **kwargs,
@@ -1212,7 +1245,7 @@ class ContinuousMixin:
             generation_config: Optional generation configuration
             num_q_padding_intervals: Number of intervals used to pad the query dimension
             num_kv_padding_intervals: Number of intervals used to pad the keys/values dimension
-            allow_prefix_sharing: A flag to allow prefix sharing if the model has only full attention layers
+            allow_block_sharing: A flag to allow block sharing if the model has some full attention layers
             record_timestamps: If set to true, the requests will have a timestamp for each token generated
             progress_bar: If set to true, a progress bar will be displayed
             **kwargs: Additional generation parameters
@@ -1228,26 +1261,30 @@ class ContinuousMixin:
 
         # Initialize manager with the batch inputs
         results = {}
-        num_requests = len(inputs)
-        with (
-            self.continuous_batching_context_manager(
-                generation_config=generation_config,
-                num_q_cuda_graphs=num_q_padding_intervals,
-                num_kv_cuda_graphs=num_kv_padding_intervals,
-                allow_prefix_sharing=allow_prefix_sharing,
-                block=True,
-                timeout=5,
-            ) as manager,
-            logging_redirect_tqdm([logger]),
-            tqdm(
-                total=num_requests,
-                disable=(not progress_bar),
-                desc=f"Solving {num_requests} requests",
-                unit="request",
-            ) as pbar,
-        ):
+        gen_cfg = self.generation_config if generation_config is None else generation_config
+        num_requests = len(inputs) * gen_cfg.num_return_sequences
+        # Prepare context managers for the main loop
+        manager_cm = self.continuous_batching_context_manager(
+            generation_config=generation_config,
+            num_q_cuda_graphs=num_q_padding_intervals,
+            num_kv_cuda_graphs=num_kv_padding_intervals,
+            allow_block_sharing=allow_block_sharing,
+            block=True,
+            timeout=5,
+        )
+        logging_cm = logging_redirect_tqdm([logger])
+        pbar_cm = tqdm(
+            total=num_requests,
+            disable=(not progress_bar),
+            desc=f"Solving {num_requests} requests",
+            unit="request",
+        )
+        # Main loop
+        with manager_cm as manager, logging_cm, pbar_cm as pbar:
             try:
-                manager.add_requests(inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens"))
+                manager.add_requests(
+                    inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens"), record_timestamps=record_timestamps
+                )
                 finished_count = 0
                 while finished_count < num_requests:
                     result = manager.get_result(timeout=1)
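
One behavioral consequence worth noting: `num_requests` now counts sequences rather than prompts, so the progress bar and the result loop both wait for `len(inputs) * num_return_sequences` completions. A small arithmetic check (values hypothetical):

    inputs = [[1, 2, 3], [4, 5]]   # two prompts
    num_return_sequences = 3
    num_requests = len(inputs) * num_return_sequences
    assert num_requests == 6       # six results must arrive before the loop exits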
@@ -101,6 +101,8 @@ class RequestState:
 
     Attributes:
         request_id (str): The ID of the generation request.
+        initial_tokens (list[int]): The initial prompt tokens.
+        num_children (int): The number of child requests.
         full_prompt_ids (list[int] | None): The token IDs of the full prompt.
         prompt_ids (list[int] | None): The token IDs currently being processed.
         remaining_prompt_ids (list[int]): The token IDs remaining to be processed (for split requests).
@@ -121,6 +123,7 @@ class RequestState:
     initial_tokens: list[int]  # Initial prompt tokens
     # Optional fields
     record_timestamps: bool = False  # Whether to record timestamps for the generated tokens
+    num_children: int = 0  # Number of child requests
     # Internal fields
     tokens_to_process: list[int] | None = None  # Token IDs currently being processed
     remaining_prefill_tokens: list[int] = field(default_factory=list)  # For split requests, prefill left to process
@@ -181,7 +184,7 @@ class RequestState:
         Returns:
             bool: True if the request is now complete, False otherwise
         """
-        # Only update if we're in decoding state
+        # Only update if we're in decoding state  # TODO: seems useless (always true) -- remove this
         if self.status != RequestStatus.DECODING:
             return False
 
@@ -227,3 +230,27 @@ class RequestState:
             error=self.error,
             timestamps=self.timestamps,
         )
+
+    def fork(self, new_request_id: str) -> "RequestState":
+        """Fork the request into a new request with the same state except for request_id, created_time and lifespan."""
+        t = time.perf_counter()
+        new_request = RequestState(
+            request_id=new_request_id,
+            initial_tokens=self.initial_tokens,
+            num_children=self.num_children,
+            tokens_to_process=self.tokens_to_process[:],
+            remaining_prefill_tokens=self.remaining_prefill_tokens[:],
+            generated_tokens=self.generated_tokens[:],
+            allocated_blocks=self.allocated_blocks,
+            position_offset=self.position_offset,
+            status=self.status,
+            max_new_tokens=self.max_new_tokens,
+            eos_token_id=self.eos_token_id,
+            streaming=self.streaming,
+            created_time=t,
+            lifespan=(t, -1),
+            timestamps=None if self.timestamps is None else self.timestamps[:],
+            error=self.error,
+            record_timestamps=self.record_timestamps,
+        )
+        return new_request
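
A self-contained toy showing the fork semantics (`ToyRequest` is a stand-in invented for illustration, not the real `RequestState`): mutable fields are copied so parent and child can diverge, while identity fields are reset per child.

    from dataclasses import dataclass, field, replace

    @dataclass
    class ToyRequest:
        request_id: str
        num_children: int = 0
        generated_tokens: list[int] = field(default_factory=list)

        def fork(self, new_request_id: str) -> "ToyRequest":
            # Copy everything, but give the child its own id and its own token list
            return replace(self, request_id=new_request_id,
                           generated_tokens=self.generated_tokens[:])

    parent = ToyRequest("req_0", num_children=2)
    children = [parent.fork(f"req_0--child-{i}") for i in range(parent.num_children)]
    print([c.request_id for c in children])  # ['req_0--child-0', 'req_0--child-1']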
@@ -36,6 +36,7 @@ class Scheduler(ABC):
         self.retain_cache_on_finish = retain_cache_on_finish
         self._cancellation_lock = threading.Lock()
         self._requests_to_cancel: set[str] = set()
+        self._requests_to_fork: list[RequestState] = []
 
     @traced
     def add_waiting_request(self, state: RequestState):
@@ -151,8 +152,13 @@ class Scheduler(ABC):
             else:
                 request_tokens = state.tokens_to_process
 
+            # If the request has one or more children, we make sure not to prefill it entirely
+            if state.num_children > 0 and token_budget >= len(request_tokens) - 1:
+                token_budget = len(request_tokens) - 1
+                self._requests_to_fork.append(state)
+
+            # Case: we can process the entire prompt/remainder
             if len(request_tokens) < token_budget:
-                # Can process the entire prompt/remainder
                 if state.status == RequestStatus.PENDING:
                     self.active_requests[state.request_id] = state
                     state.status = RequestStatus.PREFILLING
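
The clamp above keeps at least the last prompt token out of this prefill pass whenever a request has children: the parent is queued for forking first, and the final token is then processed per sequence, which is exactly where sampling lets the `num_return_sequences` outputs diverge. Worked numbers (hypothetical):

    request_tokens = list(range(10))  # a 10-token prompt
    token_budget = 64                 # plenty of room
    num_children = 2
    if num_children > 0 and token_budget >= len(request_tokens) - 1:
        token_budget = len(request_tokens) - 1
    print(token_budget)  # 9 -> one token is held back until after the fork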
@@ -161,8 +167,9 @@ class Scheduler(ABC):
                     state.status = RequestStatus.PREFILLING
                     state.tokens_to_process = state.remaining_prefill_tokens
                     state.remaining_prefill_tokens = []
+
+            # Otherwise: we need to split the request
             else:
-                # Need to split the request
                 if state.status == RequestStatus.PENDING:
                     self.active_requests[state.request_id] = state
                     state.status = RequestStatus.PREFILLING_SPLIT
@@ -229,7 +236,7 @@ class FIFOScheduler(Scheduler):
                 # Update the token budget
                 token_budget -= request_len
                 # If using prefix sharing, we make note of the blocks that will be computed in the forward pass
-                if self.cache.use_prefix_sharing:
+                if self.cache.allow_block_sharing:
                     tokens_in_current_block = state.current_len() % self.cache.block_size
                     tokens_after_forward = tokens_in_current_block + request_len
                     complete_blocks = tokens_after_forward // self.cache.block_size
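
The block accounting in these schedulers is plain modular arithmetic; a worked instance with hypothetical sizes:

    block_size = 32
    current_len = 40    # tokens already in the cache for this request
    request_len = 50    # tokens scheduled for this forward pass
    tokens_in_current_block = current_len % block_size            # 8 sit in a partial block
    tokens_after_forward = tokens_in_current_block + request_len  # 58
    complete_blocks = tokens_after_forward // block_size          # 1 block filled this pass
    print(tokens_in_current_block, tokens_after_forward, complete_blocks)  # 8 58 1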
@@ -295,7 +302,7 @@ class PrefillFirstScheduler(Scheduler):
                 # Update the token budget
                 token_budget -= request_len
                 # If using prefix sharing, we make note of the blocks that will be computed in the forward pass
-                if self.cache.use_prefix_sharing:
+                if self.cache.allow_block_sharing:
                     tokens_in_current_block = state.current_len() % self.cache.block_size
                     tokens_after_forward = tokens_in_current_block + request_len
                     complete_blocks = tokens_after_forward // self.cache.block_size
@@ -430,7 +430,7 @@ class StopStringCriteria(StoppingCriteria):
         initial_match = end_lengths > 0
 
         # Tokens continue the string if the cumsum() so far is one of the valid positions for that token
-        # Note that we're actually tracking one cumsum() for for each possible end_length
+        # Note that we're actually tracking one cumsum() for each possible end_length
         later_match = torch.any(cumsum[:, :-1, :, None] == valid_positions[:, :, :, :, None], axis=-2)
 
         # The match vector is a boolean vector that indicates which positions have valid tokens
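
For intuition, the comment fixed above describes a cumulative match: scanning tokens backwards from the end of the text, a stop-string candidate survives only while the running count of matched characters equals a position at which the next token is allowed to start. A much-simplified pure-Python rendition of that idea (not the vectorized implementation; names and shapes are illustrative):

    def ends_with_stop_string(token_infos):
        # token_infos: right-to-left list of (chars_contributed, valid_start_offsets),
        # offsets counted from the end of the stop string
        matched = 0  # running count of matched characters -- the role of the cumsum
        for chars, valid_offsets in token_infos:
            if matched not in valid_offsets:
                return False
            matched += chars
        return True

    # Stop string "abc" tokenized as "ab" + "c": scanning right-to-left,
    # "c" contributes 1 char at offset 0, then "ab" contributes 2 chars at offset 1.
    print(ends_with_stop_string([(1, {0}), (2, {1})]))  # True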