transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff shows the changes between two publicly released package versions as they were published to one of the supported public registries. It is provided for informational purposes only.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/convnext/image_processing_convnext_fast.py

@@ -20,11 +20,7 @@ import torch
 from torchvision.transforms.v2 import functional as F
 
 from ...image_processing_utils import BatchFeature
-from ...image_processing_utils_fast import (
-    BaseImageProcessorFast,
-    group_images_by_shape,
-    reorder_images,
-)
+from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images
 from ...image_transforms import get_resize_output_image_size
 from ...image_utils import (
     IMAGENET_STANDARD_MEAN,
@@ -32,6 +28,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
+    SizeDict,
 )
 from ...processing_utils import Unpack
 from ...utils import (
@@ -43,7 +40,7 @@ from .image_processing_convnext import ConvNextImageProcessorKwargs
 
 @auto_docstring
 class ConvNextImageProcessorFast(BaseImageProcessorFast):
-    resample = PILImageResampling.BILINEAR
+    resample = PILImageResampling.BICUBIC
     image_mean = IMAGENET_STANDARD_MEAN
     image_std = IMAGENET_STANDARD_STD
     size = {"shortest_edge": 384}
@@ -98,23 +95,23 @@ class ConvNextImageProcessorFast(BaseImageProcessorFast):
             resize_size = get_resize_output_image_size(
                 image, size=resize_shortest_edge, default_to_square=False, input_data_format=ChannelDimension.FIRST
             )
-            image = F.resize(
+            image = super().resize(
                 image,
-                resize_size,
+                SizeDict(height=resize_size[0], width=resize_size[1]),
                 interpolation=interpolation,
                 **kwargs,
             )
             # then crop to (shortest_edge, shortest_edge)
-            return F.center_crop(
+            return self.center_crop(
                 image,
-                (shortest_edge, shortest_edge),
+                SizeDict(height=shortest_edge, width=shortest_edge),
                 **kwargs,
             )
         else:
             # warping (no cropping) when evaluated at 384 or larger
-            return F.resize(
+            return super().resize(
                 image,
-                (shortest_edge, shortest_edge),
+                SizeDict(height=shortest_edge, width=shortest_edge),
                 interpolation=interpolation,
                 **kwargs,
             )
@@ -162,7 +159,6 @@ class ConvNextImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
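Net effect of the hunks above: ConvNextImageProcessorFast now defaults to bicubic resampling (previously bilinear), routes resizing and cropping through the SizeDict-based helpers on BaseImageProcessorFast instead of calling torchvision functionals directly, and leaves tensor stacking to BatchFeature. A minimal usage sketch, assuming a standard ConvNeXt checkpoint (the checkpoint name is illustrative, not taken from this diff):

    import torch
    from transformers import AutoImageProcessor

    # Illustrative checkpoint; any ConvNeXt checkpoint with a fast image processor behaves the same.
    processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224", use_fast=True)
    # The class-level default resample is now BICUBIC (was BILINEAR in rc0);
    # a checkpoint's own preprocessor config may still override it.
    print(processor.resample)

    images = [torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)]
    batch = processor(images=images, return_tensors="pt")
    # With the manual torch.stack removed, BatchFeature(tensor_type=...) does the batching.
    print(batch["pixel_values"].shape)  # a 4D batch, e.g. (1, 3, 224, 224) for this checkpoint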
transformers/models/convnext/modeling_convnext.py

@@ -268,7 +268,7 @@ class ConvNextModel(ConvNextPreTrainedModel):
     @can_return_tuple
     @auto_docstring
     def forward(
-        self, pixel_values: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None
+        self, pixel_values: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None, **kwargs
     ) -> BaseModelOutputWithPoolingAndNoAttention:
         if output_hidden_states is None:
             output_hidden_states = self.config.output_hidden_states
@@ -370,9 +370,7 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
     @can_return_tuple
     @auto_docstring
     def forward(
-        self,
-        pixel_values: torch.Tensor,
-        output_hidden_states: Optional[bool] = None,
+        self, pixel_values: torch.Tensor, output_hidden_states: Optional[bool] = None, **kwargs
     ) -> BackboneOutput:
         r"""
         Examples:
@@ -289,7 +289,7 @@ class ConvNextV2Model(ConvNextV2PreTrainedModel):
289
289
  @can_return_tuple
290
290
  @auto_docstring
291
291
  def forward(
292
- self, pixel_values: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None
292
+ self, pixel_values: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None, **kwargs
293
293
  ) -> BaseModelOutputWithPoolingAndNoAttention:
294
294
  if output_hidden_states is None:
295
295
  output_hidden_states = self.config.output_hidden_states
@@ -393,9 +393,7 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
393
393
  @can_return_tuple
394
394
  @auto_docstring
395
395
  def forward(
396
- self,
397
- pixel_values: torch.Tensor,
398
- output_hidden_states: Optional[bool] = None,
396
+ self, pixel_values: torch.Tensor, output_hidden_states: Optional[bool] = None, **kwargs
399
397
  ) -> BackboneOutput:
400
398
  r"""
401
399
  Examples:
@@ -89,7 +89,7 @@ class CsmGenerationMixin(GenerationMixin):
89
89
  return kept_criteria
90
90
 
91
91
  def _prepare_generation_config(
92
- self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Any
92
+ self, generation_config: Optional[GenerationConfig], **kwargs: Any
93
93
  ) -> tuple[GenerationConfig, dict]:
94
94
  """
95
95
  This method overrides [~generation.utils.GenerationMixin._prepare_generation_config].
@@ -104,9 +104,7 @@ class CsmGenerationMixin(GenerationMixin):
104
104
  kwargs = {k: v for k, v in kwargs.items() if not k.startswith("depth_decoder_")}
105
105
 
106
106
  # initialize the generation config
107
- generation_config, model_kwargs = super()._prepare_generation_config(
108
- generation_config, use_model_defaults, **kwargs
109
- )
107
+ generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
110
108
  self.depth_decoder.generation_config.update(**depth_decoder_kwargs)
111
109
 
112
110
  # ensure the depth decoder generation config is valid
@@ -209,26 +207,25 @@ class CsmGenerationMixin(GenerationMixin):
209
207
  else self.__call__
210
208
  )
211
209
 
212
- is_prefill = True
213
- while self._has_unfinished_sequences(
214
- this_peer_finished,
215
- synced_gpus,
216
- device=input_ids.device,
217
- ):
218
- # prepare model inputs
219
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
220
-
221
- # prepare variable output controls (note: some models won't accept all output controls)
222
- model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
223
- # *************** Csm specific ***************
224
- model_inputs.update({"output_hidden_states": True})
225
- # ============================================
210
+ # *************** Csm specific ***************
211
+ model_kwargs.update({"output_hidden_states": True})
226
212
 
227
- if is_prefill:
228
- outputs = self(**model_inputs, return_dict=True)
229
- is_prefill = False
230
- else:
213
+ # Assisted generation completes the prefill stage in candidate generator so that
214
+ # we don't have several `prefill` calls in one generation loop. Skip `_prefill` for assistants
215
+ if not generation_config.is_assistant:
216
+ outputs = self._prefill(input_ids, generation_config, model_kwargs)
217
+ prefill_consumed = False
218
+ else:
219
+ model_kwargs = self._get_initial_cache_position(input_ids.shape[1], input_ids.device, model_kwargs)
220
+ prefill_consumed = True
221
+
222
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
223
+ if prefill_consumed:
224
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
225
+ # prepare variable output controls (note: some models won't accept all output controls)
226
+ model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
231
227
  outputs = model_forward(**model_inputs, return_dict=True)
228
+ prefill_consumed = True
232
229
 
233
230
  # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
234
231
  model_kwargs = self._update_model_kwargs_for_generation(
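The refactor pulls the prefill call out of the decode loop: `_prefill` runs exactly once up front (skipped for assistant models), and the loop only calls the model again once that first output has been consumed. A toy sketch of the control flow, where every helper is a stand-in rather than the real API:

    def fake_forward(step: int) -> str:
        # stand-in for self._prefill(...) / model_forward(...)
        return f"logits@{step}"

    outputs = fake_forward(0)   # prefill happens once, before the loop
    prefill_consumed = False
    for step in range(1, 4):    # stand-in for _has_unfinished_sequences
        if prefill_consumed:
            outputs = fake_forward(step)
        prefill_consumed = True
        # ...generation state would be updated from `outputs` here...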
@@ -32,7 +32,7 @@ from ... import initialization as init
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  from ...masking_utils import create_causal_mask
  from ...modeling_layers import GradientCheckpointingLayer
  from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -40,6 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
  from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
+ from ...utils.generic import maybe_autocast
  from ...utils.import_utils import is_torchdynamo_compiling
  from ..auto import AutoModel
  from .configuration_csm import CsmConfig, CsmDepthDecoderConfig
@@ -135,7 +136,7 @@ class CsmRotaryEmbedding(nn.Module):
          inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

          self.register_buffer("inv_freq", inv_freq, persistent=False)
-         self.original_inv_freq = inv_freq
+         self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

      @staticmethod
      def compute_default_rope_parameters(
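Registering `original_inv_freq` as a non-persistent buffer, rather than keeping it as a plain attribute, means it follows the module across `.to(device)` moves while staying out of the checkpoint. A minimal sketch of the difference:

    import torch
    from torch import nn

    class Rotary(nn.Module):
        def __init__(self, inv_freq: torch.Tensor):
            super().__init__()
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            # .clone() decouples the saved copy from later in-place edits to inv_freq
            self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    m = Rotary(torch.arange(4, dtype=torch.float32))
    m.to("cpu")  # buffers move together with the module
    assert "original_inv_freq" not in m.state_dict()  # persistent=False keeps it out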
@@ -174,7 +175,7 @@ class CsmRotaryEmbedding(nn.Module):
          position_ids_expanded = position_ids[:, None, :].float()

          device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+         with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
              freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
              emb = torch.cat((freqs, freqs), dim=-1)
              cos = emb.cos() * self.attention_scaling
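`maybe_autocast` replaces the bare `torch.autocast` here; the apparent intent is a wrapper that degrades gracefully where autocast is unsupported. The real helper lives in `transformers.utils.generic` — the following is only a plausible sketch of the idea, not its actual implementation:

    from contextlib import nullcontext
    import torch

    def maybe_autocast(device_type: str, **kwargs):
        # Hypothetical: fall back to a no-op context instead of raising on
        # devices where torch.autocast is unavailable.
        try:
            return torch.autocast(device_type=device_type, **kwargs)
        except RuntimeError:
            return nullcontext()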
@@ -272,6 +273,7 @@ def eager_attention_forward(
      return attn_output, attn_weights


+ @use_kernelized_func(apply_rotary_pos_emb)
  class CsmAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -297,7 +299,6 @@ class CsmAttention(nn.Module):
          self.o_proj = nn.Linear(
              config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
          )
-         self.rotary_fn = apply_rotary_pos_emb

      def forward(
          self,
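The per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment gives way to the `@use_kernelized_func` class decorator, which presumably attaches the (possibly hub-kernelized) rotary function at the class level. A hypothetical sketch of that decorator pattern, not the real implementation:

    def use_kernelized_func(func):
        # Assumption: the real decorator may swap `func` for an optimized hub
        # kernel; here we only attach it as a class attribute.
        def decorator(cls):
            cls.rotary_fn = staticmethod(func)
            return cls
        return decorator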
@@ -420,6 +421,8 @@ class CsmPreTrainedModel(PreTrainedModel):
              num_codebooks = module.num_codebooks
              for i in range(num_codebooks - 1):
                  init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+         elif isinstance(module, CsmBackboneModelEmbeddings):
+             init.copy_(module.audio_tokens_offsets, torch.arange(self.config.num_codebooks) * self.config.vocab_size)


  @auto_docstring
@@ -149,6 +149,8 @@ class CsmPreTrainedModel(PreTrainedModel):
              num_codebooks = module.num_codebooks
              for i in range(num_codebooks - 1):
                  init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+         elif isinstance(module, CsmBackboneModelEmbeddings):
+             init.copy_(module.audio_tokens_offsets, torch.arange(self.config.num_codebooks) * self.config.vocab_size)


  @auto_docstring
@@ -22,6 +22,7 @@ import torch
  from torch import nn
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+ from ... import initialization as init
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
  from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
@@ -187,6 +188,13 @@ class CTRLPreTrainedModel(PreTrainedModel):
      config: CTRLConfig
      base_model_prefix = "transformer"

+     def _init_weights(self, module):
+         super()._init_weights(module)
+         if isinstance(module, CTRLModel):
+             init.copy_(
+                 module.pos_encoding, positional_encoding(module.config.n_positions, module.d_model_size, torch.float)
+             )
+

  @auto_docstring
  class CTRLModel(CTRLPreTrainedModel):
@@ -196,7 +204,9 @@
          self.d_model_size = config.n_embd
          self.num_layers = config.n_layer

-         self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float)
+         self.register_buffer(
+             "pos_encoding", positional_encoding(config.n_positions, self.d_model_size, torch.float), persistent=False
+         )

          self.w = nn.Embedding(config.vocab_size, config.n_embd)

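CTRL's `pos_encoding` is now a registered buffer, re-filled in `_init_weights` via `init.copy_`. The table itself is the standard sinusoidal encoding (sine on even dimensions, cosine on odd); a self-contained sketch of such a `positional_encoding` helper:

    import torch

    def positional_encoding(n_positions: int, d_model: int, dtype=torch.float):
        position = torch.arange(n_positions, dtype=dtype).unsqueeze(1)
        dim = torch.arange(d_model, dtype=dtype).unsqueeze(0)
        # each pair of dims (2i, 2i+1) shares one frequency: 1 / 10000^(2i/d_model)
        angle_rates = 1.0 / torch.pow(10000.0, (2 * (dim // 2)) / d_model)
        angles = position * angle_rates
        pe = torch.empty(n_positions, d_model, dtype=dtype)
        pe[:, 0::2] = torch.sin(angles[:, 0::2])
        pe[:, 1::2] = torch.cos(angles[:, 1::2])
        return pe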
@@ -470,7 +480,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel, GenerationMixin):
              attentions=transformer_outputs.attentions,
          )

-     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
+     def prepare_inputs_for_generation(
+         self, input_ids, past_key_values=None, use_cache=None, is_first_iteration=False, **kwargs
+     ):
          # Overwritten -- inputs_embeds not working properly

          # only last tokens for inputs_ids if past is defined in kwargs
@@ -534,6 +546,7 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -497,9 +497,13 @@ class CvtPreTrainedModel(PreTrainedModel):
              init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
              if module.bias is not None:
                  init.zeros_(module.bias)
-         elif isinstance(module, nn.LayerNorm):
+         elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
              init.zeros_(module.bias)
              init.ones_(module.weight)
+             if getattr(module, "running_mean", None) is not None:
+                 init.zeros_(module.running_mean)
+                 init.ones_(module.running_var)
+                 init.zeros_(module.num_batches_tracked)
          elif isinstance(module, CvtStage):
              if self.config.cls_token[module.stage]:
                  init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
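The extra branch matters because a batch norm's running statistics are buffers, not parameters, so a parameter-only init pass would leave them untouched. A quick demonstration:

    from torch import nn

    bn = nn.BatchNorm2d(8)
    print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']
    print([name for name, _ in bn.named_buffers()])
    # ['running_mean', 'running_var', 'num_batches_tracked']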
@@ -523,6 +527,7 @@ class CvtModel(CvtPreTrainedModel):
          pixel_values: Optional[torch.Tensor] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple, BaseModelOutputWithCLSToken]:
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -577,6 +582,7 @@ class CvtForImageClassification(CvtPreTrainedModel):
          labels: Optional[torch.Tensor] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -28,7 +28,7 @@ from torch import nn
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
  from ...modeling_layers import GradientCheckpointingLayer
@@ -37,7 +37,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
- from ...utils.generic import check_model_inputs
+ from ...utils.generic import check_model_inputs, maybe_autocast
  from .configuration_cwm import CwmConfig


@@ -58,7 +58,7 @@ class CwmRotaryEmbedding(nn.Module):
          inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

          self.register_buffer("inv_freq", inv_freq, persistent=False)
-         self.original_inv_freq = inv_freq
+         self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

      @staticmethod
      def compute_default_rope_parameters(
@@ -97,7 +97,7 @@ class CwmRotaryEmbedding(nn.Module):
          position_ids_expanded = position_ids[:, None, :].float()

          device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+         with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
              freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
              emb = torch.cat((freqs, freqs), dim=-1)
              cos = emb.cos() * self.attention_scaling
@@ -179,6 +179,7 @@ def eager_attention_forward(
      return attn_output, attn_weights


+ @use_kernelized_func(apply_rotary_pos_emb)
  class CwmAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -196,7 +197,6 @@ class CwmAttention(nn.Module):
          self.k_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
          self.v_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
          self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-         self.rotary_fn = apply_rotary_pos_emb
          self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None

      def forward(
@@ -47,7 +47,7 @@ class DFineConfig(PreTrainedConfig):
          The epsilon used by the layer normalization layers.
      batch_norm_eps (`float`, *optional*, defaults to 1e-05):
          The epsilon used by the batch normalization layers.
-     backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
+     backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `HGNetV2Config()`):
          The configuration of the backbone model.
      backbone (`str`, *optional*):
          Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -288,8 +288,7 @@
              )
              backbone_model_type = "hgnet_v2"
              config_class = CONFIG_MAPPING[backbone_model_type]
-             # this will map it to RTDetrResNetConfig
-             # note: we can instead create HGNetV2Config
+             # this will map it to HGNetV2Config
              # and we would need to create HGNetV2Backbone
              backbone_config = config_class(
                  num_channels=3,
@@ -395,8 +394,8 @@
              raise ValueError(
                  f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}"
              )
+
          super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-         self.tie_encoder_decoder = True


  __all__ = ["DFineConfig"]
@@ -483,6 +483,9 @@ class DFinePreTrainedModel(PreTrainedModel):
              init.constant_(module.attention_weights.weight, 0.0)
              init.constant_(module.attention_weights.bias, 0.0)

+             num_points_scale = [1 / n for n in module.num_points_list for _ in range(n)]
+             init.copy_(module.num_points_scale, torch.tensor(num_points_scale, dtype=torch.float32))
+
          if isinstance(module, DFineModel):
              prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
              bias = float(-math.log((1 - prior_prob) / prior_prob))
@@ -493,6 +496,10 @@ class DFinePreTrainedModel(PreTrainedModel):
              init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
              if module.bias is not None:
                  init.zeros_(module.bias)
+             if getattr(module, "running_mean", None) is not None:
+                 init.zeros_(module.running_mean)
+                 init.ones_(module.running_var)
+                 init.zeros_(module.num_batches_tracked)

          if isinstance(module, DFineGate):
              bias = float(-math.log((1 - 0.5) / 0.5))
@@ -681,6 +688,7 @@ class DFineDecoder(DFinePreTrainedModel):
          memory_mask=None,
          output_attentions=None,
          return_dict=None,
+         **kwargs,
      ) -> DFineDecoderOutput:
          r"""
          Args:
@@ -837,6 +845,45 @@
          )


+ class DFineFrozenBatchNorm2d(nn.Module):
+     """
+     BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+     Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than
+     torchvision.models.resnet[18,34,50,101] produce nans.
+     """
+
+     def __init__(self, n):
+         super().__init__()
+         self.register_buffer("weight", torch.ones(n))
+         self.register_buffer("bias", torch.zeros(n))
+         self.register_buffer("running_mean", torch.zeros(n))
+         self.register_buffer("running_var", torch.ones(n))
+
+     def _load_from_state_dict(
+         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+     ):
+         num_batches_tracked_key = prefix + "num_batches_tracked"
+         if num_batches_tracked_key in state_dict:
+             del state_dict[num_batches_tracked_key]
+
+         super()._load_from_state_dict(
+             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+         )
+
+     def forward(self, x):
+         # move reshapes to the beginning
+         # to make it user-friendly
+         weight = self.weight.reshape(1, -1, 1, 1)
+         bias = self.bias.reshape(1, -1, 1, 1)
+         running_var = self.running_var.reshape(1, -1, 1, 1)
+         running_mean = self.running_mean.reshape(1, -1, 1, 1)
+         epsilon = 1e-5
+         scale = weight * (running_var + epsilon).rsqrt()
+         bias = bias - running_mean * scale
+         return x * scale + bias
+
+
  @dataclass
  @auto_docstring(
      custom_intro="""
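`DFineFrozenBatchNorm2d.forward` folds eval-mode batch norm, y = (x - mean) / sqrt(var + eps) * weight + bias, into a single per-channel scale and shift. A quick numerical check of that equivalence (random statistics, eps = 1e-5):

    import torch
    import torch.nn.functional as F

    n, eps = 4, 1e-5
    x = torch.randn(2, n, 5, 5)
    weight, bias = torch.rand(n) + 0.5, torch.randn(n)
    mean, var = torch.randn(n), torch.rand(n) + 0.5

    scale = (weight * (var + eps).rsqrt()).reshape(1, -1, 1, 1)
    shift = bias.reshape(1, -1, 1, 1) - mean.reshape(1, -1, 1, 1) * scale
    folded = x * scale + shift

    reference = F.batch_norm(x, mean, var, weight, bias, training=False, eps=eps)
    assert torch.allclose(folded, reference, atol=1e-5)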
@@ -895,45 +942,6 @@ class DFineModelOutput(ModelOutput):
      denoising_meta_values: Optional[dict] = None


- class DFineFrozenBatchNorm2d(nn.Module):
-     """
-     BatchNorm2d where the batch statistics and the affine parameters are fixed.
-
-     Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than
-     torchvision.models.resnet[18,34,50,101] produce nans.
-     """
-
-     def __init__(self, n):
-         super().__init__()
-         self.register_buffer("weight", torch.ones(n))
-         self.register_buffer("bias", torch.zeros(n))
-         self.register_buffer("running_mean", torch.zeros(n))
-         self.register_buffer("running_var", torch.ones(n))
-
-     def _load_from_state_dict(
-         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-     ):
-         num_batches_tracked_key = prefix + "num_batches_tracked"
-         if num_batches_tracked_key in state_dict:
-             del state_dict[num_batches_tracked_key]
-
-         super()._load_from_state_dict(
-             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-         )
-
-     def forward(self, x):
-         # move reshapes to the beginning
-         # to make it user-friendly
-         weight = self.weight.reshape(1, -1, 1, 1)
-         bias = self.bias.reshape(1, -1, 1, 1)
-         running_var = self.running_var.reshape(1, -1, 1, 1)
-         running_mean = self.running_mean.reshape(1, -1, 1, 1)
-         epsilon = 1e-5
-         scale = weight * (running_var + epsilon).rsqrt()
-         bias = bias - running_mean * scale
-         return x * scale + bias
-
-
  def replace_batch_norm(model):
      r"""
      Recursively replace all `torch.nn.BatchNorm2d` with `DFineFrozenBatchNorm2d`.
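The class was only moved earlier in the file, ahead of its first use; `replace_batch_norm` itself is unchanged. For orientation, the recursion it performs looks roughly like this sketch (the real function may additionally guard against meta-device weights):

    from torch import nn

    def replace_batch_norm(model):
        # Swap each BatchNorm2d child for the frozen variant in place,
        # carrying over affine parameters and running statistics.
        for name, module in model.named_children():
            if isinstance(module, nn.BatchNorm2d):
                frozen = DFineFrozenBatchNorm2d(module.num_features)
                frozen.weight.data.copy_(module.weight)
                frozen.bias.data.copy_(module.bias)
                frozen.running_mean.data.copy_(module.running_mean)
                frozen.running_var.data.copy_(module.running_var)
                model._modules[name] = frozen
            if len(list(module.children())) > 0:
                replace_batch_norm(module)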
@@ -1247,6 +1255,7 @@ class DFineModel(DFinePreTrainedModel):
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple[torch.FloatTensor], DFineModelOutput]:
          r"""
          inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -33,6 +33,7 @@ from ..rt_detr.modeling_rt_detr import (
      RTDetrDecoderOutput,
      RTDetrEncoder,
      RTDetrForObjectDetection,
+     RTDetrFrozenBatchNorm2d,
      RTDetrHybridEncoder,
      RTDetrMLPPredictionHead,
      RTDetrModel,
@@ -66,7 +67,7 @@ class DFineConfig(PreTrainedConfig):
          The epsilon used by the layer normalization layers.
      batch_norm_eps (`float`, *optional*, defaults to 1e-05):
          The epsilon used by the batch normalization layers.
-     backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
+     backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `HGNetV2Config()`):
          The configuration of the backbone model.
      backbone (`str`, *optional*):
          Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -307,8 +308,7 @@
              )
              backbone_model_type = "hgnet_v2"
              config_class = CONFIG_MAPPING[backbone_model_type]
-             # this will map it to RTDetrResNetConfig
-             # note: we can instead create HGNetV2Config
+             # this will map it to HGNetV2Config
              # and we would need to create HGNetV2Backbone
              backbone_config = config_class(
                  num_channels=3,
@@ -414,8 +414,8 @@
              raise ValueError(
                  f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}"
              )
+
          super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-         self.tie_encoder_decoder = True


  class DFineMultiscaleDeformableAttention(nn.Module):
@@ -628,6 +628,9 @@ class DFinePreTrainedModel(RTDetrPreTrainedModel):
              init.constant_(module.attention_weights.weight, 0.0)
              init.constant_(module.attention_weights.bias, 0.0)

+             num_points_scale = [1 / n for n in module.num_points_list for _ in range(n)]
+             init.copy_(module.num_points_scale, torch.tensor(num_points_scale, dtype=torch.float32))
+
          if isinstance(module, DFineModel):
              prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
              bias = float(-math.log((1 - prior_prob) / prior_prob))
@@ -638,6 +641,10 @@ class DFinePreTrainedModel(RTDetrPreTrainedModel):
              init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
              if module.bias is not None:
                  init.zeros_(module.bias)
+             if getattr(module, "running_mean", None) is not None:
+                 init.zeros_(module.running_mean)
+                 init.ones_(module.running_var)
+                 init.zeros_(module.num_batches_tracked)

          if isinstance(module, DFineGate):
              bias = float(-math.log((1 - 0.5) / 0.5))
@@ -726,6 +733,7 @@ class DFineDecoder(RTDetrDecoder):
          memory_mask=None,
          output_attentions=None,
          return_dict=None,
+         **kwargs,
      ) -> DFineDecoderOutput:
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
@@ -850,6 +858,10 @@
          )


+ class DFineFrozenBatchNorm2d(RTDetrFrozenBatchNorm2d):
+     pass
+
+
  class DFineModel(RTDetrModel):
      def __init__(self, config: DFineConfig):
          super().__init__(config)
@@ -37,7 +37,7 @@ class DabDetrConfig(PreTrainedConfig):
      use_timm_backbone (`bool`, *optional*, defaults to `True`):
          Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
          API.
-     backbone_config (`PreTrainedConfig` or `dict`, *optional*):
+     backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `ResNetConfig()`):
          The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
          case it will default to `ResNetConfig()`.
      backbone (`str`, *optional*, defaults to `"resnet50"`):
@@ -255,8 +255,8 @@ class DabDetrConfig(PreTrainedConfig):
          self.temperature_height = temperature_height
          self.sine_position_embedding_scale = sine_position_embedding_scale
          self.initializer_bias_prior_prob = initializer_bias_prior_prob
+
          super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-         self.tie_encoder_decoder = True  # weights have to be tied for this model


  __all__ = ["DabDetrConfig"]
@@ -826,7 +826,7 @@ class DabDetrPreTrainedModel(PreTrainedModel):
              init.zeros_(module.q_linear.bias)
              init.xavier_uniform_(module.k_linear.weight, gain=xavier_std)
              init.xavier_uniform_(module.q_linear.weight, gain=xavier_std)
-         if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+         if isinstance(module, (nn.Linear, nn.Conv2d)):
              init.normal_(module.weight, mean=0.0, std=std)
              if module.bias is not None:
                  init.zeros_(module.bias)
@@ -886,6 +886,7 @@ class DabDetrEncoder(DabDetrPreTrainedModel):
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ):
          r"""
          Args:
@@ -1016,6 +1017,7 @@ class DabDetrDecoder(DabDetrPreTrainedModel):
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ):
          r"""
          Args:
@@ -1222,6 +1224,7 @@ class DabDetrModel(DabDetrPreTrainedModel):
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple[torch.FloatTensor], DabDetrModelOutput]:
          r"""
          decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -1469,6 +1472,7 @@ class DabDetrForObjectDetection(DabDetrPreTrainedModel):
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple[torch.FloatTensor], DabDetrObjectDetectionOutput]:
          r"""
          decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):