transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/models/sam3/modeling_sam3.py

@@ -14,9 +14,8 @@
  # limitations under the License.


- import collections.abc
  import math
- from collections.abc import Callable
+ from collections.abc import Callable, Iterable
  from dataclasses import dataclass
  from typing import Optional, Union

@@ -40,7 +39,7 @@ from ...modeling_outputs import (
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
  from ...pytorch_utils import compile_compatible_method_lru_cache
- from ...utils import auto_docstring
+ from ...utils import auto_docstring, logging
  from ...utils.generic import TransformersKwargs, check_model_inputs
  from ..auto import AutoModel
  from .configuration_sam3 import (
@@ -54,6 +53,9 @@ from .configuration_sam3 import (
  )


+ logger = logging.get_logger(__name__)
+
+
  @dataclass
  @auto_docstring
  class Sam3VisionEncoderOutput(ModelOutput):
@@ -123,8 +125,8 @@ class Sam3DETRDecoderOutput(ModelOutput):
  Decoder hidden states from all layers.
  reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`):
  Predicted reference boxes from all decoder layers in (cx, cy, w, h) format.
- presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size)`, *optional*):
- Presence logits from all decoder layers (None if using instance queries).
+ presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size, 1)`):
+ Presence logits from all decoder layers indicating object presence confidence.
  hidden_states (`tuple[torch.FloatTensor]`, *optional*):
  Tuple of hidden states from all decoder layers.
  attentions (`tuple[torch.FloatTensor]`, *optional*):
@@ -133,7 +135,7 @@ class Sam3DETRDecoderOutput(ModelOutput):

  intermediate_hidden_states: torch.FloatTensor = None
  reference_boxes: torch.FloatTensor = None
- presence_logits: Optional[torch.FloatTensor] = None
+ presence_logits: torch.FloatTensor = None
  hidden_states: Optional[tuple[torch.FloatTensor]] = None
  attentions: Optional[tuple[torch.FloatTensor]] = None

@@ -372,6 +374,19 @@ class Sam3Attention(nn.Module):
  if self.config._attn_implementation != "eager":
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

+ if (
+ "flash" in self.config._attn_implementation
+ and attention_mask is not None
+ and attention_mask.dtype != torch.bool
+ ):
+ # Relative position bias tensors are represented as float masks and are incompatible with Flash Attention
+ # Fallback to SDPA for this call only so the rest of the model can still benefit from FA
+ attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
+ logger.warning_once(
+ "Sam3Attention: falling back to SDPA for relative-position cross-attention because "
+ "Flash Attention does not support additive bias masks."
+ )
+
  attn_output, attn_weights = attention_interface(
  self,
  query,
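
The new guard keys off the mask dtype: Flash Attention kernels only accept boolean padding masks (or none at all), while the relative position bias used here is an additive float tensor that must be summed into the attention scores, which the SDPA path supports via its attn_mask argument. A minimal, self-contained sketch of the same dispatch rule in plain PyTorch; the helper name and signature below are illustrative, not library API:

    import torch
    import torch.nn.functional as F

    def attend_with_fallback(query, key, value, attn_bias=None, impl="flash_attention_2"):
        # Flash Attention handles only boolean padding masks; an additive float
        # bias (e.g. a relative-position bias) reroutes this one call to SDPA,
        # where attn_mask is added to the attention scores before the softmax.
        if "flash" in impl and attn_bias is not None and attn_bias.dtype != torch.bool:
            impl = "sdpa"
        if impl == "sdpa":
            return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
        raise NotImplementedError(f"{impl} is not wired up in this sketch")

    q = k = v = torch.randn(1, 8, 16, 32)  # [batch, heads, seq, head_dim]
    bias = torch.randn(1, 8, 16, 16)       # float relative-position bias
    out = attend_with_fallback(q, k, v, attn_bias=bias)  # falls back to SDPA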
@@ -531,8 +546,8 @@ class Sam3ViTPatchEmbeddings(nn.Module):
  image_size, patch_size = config.pretrain_image_size, config.patch_size
  num_channels, hidden_size = config.num_channels, config.hidden_size

- image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
- patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ image_size = image_size if isinstance(image_size, Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, Iterable) else (patch_size, patch_size)
  num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  self.image_size = image_size
  self.patch_size = patch_size
@@ -542,7 +557,7 @@ class Sam3ViTPatchEmbeddings(nn.Module):
  self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=False)

  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
- embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ embeddings = self.projection(pixel_values.to(self.projection.weight.dtype)).flatten(2).transpose(1, 2)
  return embeddings


@@ -938,6 +953,7 @@ class Sam3FPNLayer(nn.Module):
  self.proj2 = nn.Conv2d(in_channels=fpn_dim, out_channels=fpn_dim, kernel_size=3, padding=1)

  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = hidden_states.to(self.proj1.weight.dtype)
  for layer in self.scale_layers:
  hidden_states = layer(hidden_states)

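Both dtype casts above (Sam3ViTPatchEmbeddings.forward and Sam3FPNLayer.forward) address the same failure mode: a checkpoint loaded in half precision receiving float32 inputs, the usual output dtype of image processors, would hit a dtype mismatch inside nn.Conv2d. Casting the input to the layer's own weight dtype makes the convolutions robust to whatever dtype the model was loaded in. A minimal sketch of the fix:

    import torch
    from torch import nn

    conv = nn.Conv2d(3, 8, kernel_size=2, stride=2).to(torch.bfloat16)
    pixels = torch.randn(1, 3, 4, 4)  # float32, as image processors typically emit

    # conv(pixels) would raise a dtype mismatch (float32 input vs bfloat16 weights)
    out = conv(pixels.to(conv.weight.dtype))
    assert out.dtype == torch.bfloat16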
@@ -1253,7 +1269,7 @@ class Sam3DetrEncoderLayer(nn.Module):
  vision_feats: Tensor,
  prompt_feats: Tensor,
  vision_pos_encoding: Tensor,
- prompt_mask: Tensor,
+ prompt_cross_attn_mask: Optional[Tensor] = None,
  **kwargs: Unpack[TransformersKwargs],
  ):
  """
@@ -1263,7 +1279,7 @@
  vision_feats: Vision features [batch_size, vision_len, hidden_size] (main hidden states)
  prompt_feats: Text prompt features [batch_size, text_len, hidden_size]
  vision_pos_encoding: Position encoding for vision [batch_size, vision_len, hidden_size]
- prompt_mask: Padding mask for prompts [batch_size, text_len] where True=valid, False=padding
+ prompt_cross_attn_mask: Cross-attention mask for prompt features

  Returns:
  Updated vision features [batch_size, vision_len, hidden_size]
@@ -1284,15 +1300,6 @@
  residual = hidden_states
  hidden_states = self.layer_norm2(hidden_states)

- prompt_cross_attn_mask = None
- if prompt_mask is not None:
- prompt_cross_attn_mask = create_bidirectional_mask(
- config=self.config,
- input_embeds=hidden_states,
- attention_mask=prompt_mask,
- encoder_hidden_states=prompt_feats,
- )
-
  hidden_states, _ = self.cross_attn(
  query=hidden_states,
  key=prompt_feats,
@@ -1412,13 +1419,22 @@ class Sam3DetrEncoder(Sam3PreTrainedModel):
  spatial_shapes,
  ) = self._prepare_multilevel_features(vision_features, vision_pos_embeds)

+ prompt_cross_attn_mask = None
+ if text_mask is not None:
+ prompt_cross_attn_mask = create_bidirectional_mask(
+ config=self.config,
+ input_embeds=features_flattened,
+ attention_mask=text_mask,
+ encoder_hidden_states=text_features,
+ )
+
  hidden_states = features_flattened
  for layer in self.layers:
  hidden_states = layer(
  hidden_states,
  prompt_feats=text_features,
  vision_pos_encoding=pos_embeds_flattened,
- prompt_mask=text_mask,
+ prompt_cross_attn_mask=prompt_cross_attn_mask,
  **kwargs,
  )
  return Sam3DETREncoderOutput(
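
Together with the encoder-layer change above, this hunk hoists mask construction out of the per-layer loop: the cross-attention mask depends only on the text padding mask and the flattened features' shape, neither of which changes between layers, so it is now built once and reused instead of being rebuilt in every layer. A sketch of the idea, assuming create_bidirectional_mask ultimately expands a [batch, text_len] padding mask (True = valid) into an additive float mask broadcastable over heads and query positions:

    import torch

    def to_additive_cross_mask(padding_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
        # [batch, text_len] bool -> [batch, 1, 1, text_len] additive float mask;
        # padded key positions receive a large negative score.
        mask = torch.zeros(padding_mask.shape[0], 1, 1, padding_mask.shape[1], dtype=dtype)
        return mask.masked_fill(~padding_mask[:, None, None, :], torch.finfo(dtype).min)

    text_mask = torch.tensor([[True, True, False]])
    cross_mask = to_additive_cross_mask(text_mask)  # built once before the loop
    for _ in range(6):
        pass  # every encoder layer reuses cross_mask unchanged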
@@ -1484,31 +1500,27 @@ class Sam3DetrDecoderLayer(nn.Module):
  text_features: torch.Tensor,
  vision_features: torch.Tensor,
  vision_pos_encoding: torch.Tensor,
- text_mask: Optional[torch.Tensor] = None,
+ text_cross_attn_mask: Optional[torch.Tensor] = None,
  vision_cross_attn_mask: Optional[torch.Tensor] = None,
- presence_token: Optional[torch.Tensor] = None,
  **kwargs: Unpack[TransformersKwargs],
- ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ ) -> torch.Tensor:
  """
  Forward pass for decoder layer.

  Args:
- hidden_states: Query features [batch_size, num_queries, hidden_size]
+ hidden_states: Query features [batch_size, num_queries + 1, hidden_size] (includes presence token at position 0)
  query_pos: Query position embeddings [batch_size, num_queries, hidden_size]
  text_features: Text features [batch_size, seq_len, hidden_size]
  vision_features: Vision features [batch_size, height*width, hidden_size]
  vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
- text_mask: Text padding mask [batch_size, seq_len] where True=valid, False=padding
- vision_cross_attn_mask: Vision cross-attention mask [batch_size, num_heads, num_queries, height*width]
- presence_token: Optional presence token [batch_size, 1, hidden_size]
+ text_cross_attn_mask: Text cross-attention mask
+ vision_cross_attn_mask: Vision cross-attention mask, already expanded for presence token

  Returns:
- Tuple of (updated hidden states, updated presence token)
+ Updated hidden states (including presence token at position 0)
  """
- # Concatenate presence token if provided
- if presence_token is not None:
- hidden_states = torch.cat([presence_token, hidden_states], dim=1)
- query_pos = torch.cat([torch.zeros_like(presence_token), query_pos], dim=1)
+ # Prepend zeros to query_pos for presence token
+ query_pos = F.pad(query_pos, (0, 0, 1, 0), mode="constant", value=0)

  # Self-attention with query position encoding
  residual = hidden_states
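
The F.pad call is the compact replacement for the old torch.cat([torch.zeros_like(presence_token), query_pos], dim=1): the pad tuple is read from the last dimension backwards, so (0, 0, 1, 0) leaves the hidden dimension untouched and prepends one zero row along the query dimension, giving the presence token at index 0 an all-zero position embedding. A worked example:

    import torch
    import torch.nn.functional as F

    query_pos = torch.randn(2, 5, 8)         # [batch, num_queries, hidden]
    padded = F.pad(query_pos, (0, 0, 1, 0))  # (0, 0): last dim; (1, 0): one row before dim -2
    assert padded.shape == (2, 6, 8)
    assert padded[:, 0].abs().sum() == 0     # slot 0 (presence token) is all zeros
    assert torch.equal(padded[:, 1:], query_pos)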
@@ -1527,15 +1539,6 @@
  residual = hidden_states
  query_with_pos = hidden_states + query_pos

- text_cross_attn_mask = None
- if text_mask is not None:
- text_cross_attn_mask = create_bidirectional_mask(
- config=self.config,
- input_embeds=hidden_states,
- attention_mask=text_mask,
- encoder_hidden_states=text_features,
- )
-
  attn_output, _ = self.text_cross_attn(
  query=query_with_pos,
  key=text_features,
@@ -1546,20 +1549,6 @@
  hidden_states = residual + self.text_cross_attn_dropout(attn_output)
  hidden_states = self.text_cross_attn_layer_norm(hidden_states)

- # Expand vision cross-attention mask for presence token if needed
- combined_vision_mask = vision_cross_attn_mask
- if presence_token is not None and combined_vision_mask is not None:
- batch_size, num_heads = combined_vision_mask.shape[:2]
- presence_mask = torch.zeros(
- batch_size,
- num_heads,
- 1,
- combined_vision_mask.shape[-1],
- device=combined_vision_mask.device,
- dtype=combined_vision_mask.dtype,
- )
- combined_vision_mask = torch.cat([presence_mask, combined_vision_mask], dim=2)
-
  # Vision cross-attention: queries attend to vision features (with RPB)
  residual = hidden_states
  query_with_pos = hidden_states + query_pos
@@ -1568,7 +1557,7 @@
  query=query_with_pos,
  key=key_with_pos,
  value=vision_features,
- attention_mask=combined_vision_mask,
+ attention_mask=vision_cross_attn_mask,
  **kwargs,
  )
  hidden_states = residual + self.vision_cross_attn_dropout(attn_output)
@@ -1580,13 +1569,7 @@
  hidden_states = residual + self.mlp_dropout(hidden_states)
  hidden_states = self.mlp_layer_norm(hidden_states)

- # Extract presence token if it was added
- presence_token_out = None
- if presence_token is not None:
- presence_token_out = hidden_states[:, :1]
- hidden_states = hidden_states[:, 1:]
-
- return hidden_states, presence_token_out
+ return hidden_states


  class Sam3DetrDecoder(Sam3PreTrainedModel):
@@ -1715,11 +1698,23 @@ class Sam3DetrDecoder(Sam3PreTrainedModel):
  """
  batch_size = vision_features.shape[0]

- hidden_states = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
+ query_embeds = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
  reference_boxes = self.reference_points.weight.unsqueeze(0).expand(batch_size, -1, -1)
  reference_boxes = reference_boxes.sigmoid()
  presence_token = self.presence_token.weight.unsqueeze(0).expand(batch_size, -1, -1)

+ # Concatenate presence token with query embeddings
+ hidden_states = torch.cat([presence_token, query_embeds], dim=1)
+
+ text_cross_attn_mask = None
+ if text_mask is not None:
+ text_cross_attn_mask = create_bidirectional_mask(
+ config=self.config,
+ input_embeds=hidden_states,
+ attention_mask=text_mask,
+ encoder_hidden_states=text_features,
+ )
+
  intermediate_outputs = []
  intermediate_boxes = [reference_boxes]
  intermediate_presence_logits = []
@@ -1734,43 +1729,45 @@
  vision_cross_attn_mask = None
  if spatial_shapes is not None and spatial_shapes.shape[0] == 1:
  spatial_shape = (spatial_shapes[0, 0], spatial_shapes[0, 1])
- vision_cross_attn_mask = self._get_rpb_matrix(reference_boxes, spatial_shape)
+ rpb_matrix = self._get_rpb_matrix(reference_boxes, spatial_shape)
+ # Prepend zeros row for presence token (it attends to all vision tokens equally)
+ vision_cross_attn_mask = F.pad(rpb_matrix, (0, 0, 1, 0), mode="constant", value=0)

- hidden_states, presence_token = layer(
+ hidden_states = layer(
  hidden_states,
  query_pos=query_pos,
  text_features=text_features,
  vision_features=vision_features,
  vision_pos_encoding=vision_pos_encoding,
- text_mask=text_mask,
+ text_cross_attn_mask=text_cross_attn_mask,
  vision_cross_attn_mask=vision_cross_attn_mask,
- presence_token=presence_token,
  **kwargs,
  )

+ # Extract query hidden states (without presence token) for box refinement
+ query_hidden_states = hidden_states[:, 1:]
+
  # Box refinement: predict delta and update reference boxes
  reference_boxes_before_sigmoid = inverse_sigmoid(reference_boxes)
- delta_boxes = self.box_head(self.output_layer_norm(hidden_states))
+ delta_boxes = self.box_head(self.output_layer_norm(query_hidden_states))
  new_reference_boxes = (delta_boxes + reference_boxes_before_sigmoid).sigmoid()
  reference_boxes = new_reference_boxes.detach()

- intermediate_outputs.append(self.output_layer_norm(hidden_states))
+ intermediate_outputs.append(self.output_layer_norm(query_hidden_states))
  intermediate_boxes.append(new_reference_boxes)

  # Process presence token
- if presence_token is not None:
- presence_logits = self.presence_head(self.presence_layer_norm(presence_token)).squeeze(-1)
- presence_logits = presence_logits.clamp(
- min=-self.clamp_presence_logit_max_val, max=self.clamp_presence_logit_max_val
- )
- intermediate_presence_logits.append(presence_logits)
+ presence_hidden = hidden_states[:, :1]
+ presence_logits = self.presence_head(self.presence_layer_norm(presence_hidden)).squeeze(-1)
+ presence_logits = presence_logits.clamp(
+ min=-self.clamp_presence_logit_max_val, max=self.clamp_presence_logit_max_val
+ )
+ intermediate_presence_logits.append(presence_logits)

  # Stack outputs from all layers
  intermediate_outputs = torch.stack(intermediate_outputs)
  intermediate_boxes = torch.stack(intermediate_boxes[:-1])
- intermediate_presence_logits = (
- torch.stack(intermediate_presence_logits) if intermediate_presence_logits else None
- )
+ intermediate_presence_logits = torch.stack(intermediate_presence_logits)

  return Sam3DETRDecoderOutput(
  intermediate_hidden_states=intermediate_outputs,
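
Taken together, these decoder hunks move the presence token out of the per-layer concatenate/split dance: it is prepended once before the loop, rides through every layer at position 0 of hidden_states, the relative-position-bias mask gets a matching zero row via the same F.pad pattern, and the decoder simply slices [:, :1] for the presence head and [:, 1:] for the box heads. A minimal sketch of that slicing contract, with illustrative shapes:

    import torch

    num_queries, hidden = 5, 8
    hidden_states = torch.randn(2, 1 + num_queries, hidden)  # presence token at position 0

    presence_hidden = hidden_states[:, :1]  # -> presence head
    query_hidden = hidden_states[:, 1:]     # -> box refinement and output heads

    # The split is lossless: the two views reassemble into the layer output.
    assert torch.equal(torch.cat([presence_hidden, query_hidden], dim=1), hidden_states)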
transformers/models/sam3_tracker/modeling_sam3_tracker.py

@@ -107,7 +107,12 @@ class Sam3TrackerFeedForward(nn.Module):
  return hidden_states


- @auto_docstring
+ @auto_docstring(
+ custom_intro="""
+ Segment Anything Model 3 (SAM 3) for generating segmentation masks, given an input image and
+ input points and labels, boxes, or masks.
+ """
+ )
  class Sam3TrackerPreTrainedModel(PreTrainedModel):
  config_class = Sam3TrackerConfig
  base_model_prefix = "sam3_tracker"

transformers/models/sam3_tracker/modular_sam3_tracker.py

@@ -136,7 +136,12 @@ class Sam3TrackerFeedForward(Sam2FeedForward):
  pass


- @auto_docstring
+ @auto_docstring(
+ custom_intro="""
+ Segment Anything Model 3 (SAM 3) for generating segmentation masks, given an input image and
+ input points and labels, boxes, or masks.
+ """
+ )
  class Sam3TrackerPreTrainedModel(Sam2PreTrainedModel):
  @torch.no_grad()
  def _init_weights(self, module):

transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py

@@ -1719,6 +1719,7 @@ class Sam3TrackerVideoModel(Sam3TrackerVideoPreTrainedModel):
  frame: Optional[torch.Tensor] = None,
  reverse: bool = False,
  run_mem_encoder: bool = True,
+ **kwargs,
  ) -> Sam3TrackerVideoSegmentationOutput:
  r"""
  inference_session (`Sam3TrackerVideoInferenceSession`):
transformers/models/sam3_video/modeling_sam3_video.py

@@ -1697,6 +1697,7 @@ class Sam3VideoModel(Sam3VideoPreTrainedModel):
  frame_idx: Optional[int] = None,
  frame: Optional[torch.Tensor] = None,
  reverse: bool = False,
+ **kwargs,
  ):
  r"""
  inference_session (`Sam3VideoInferenceSession`):
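
The **kwargs added to this and the surrounding forward signatures is mechanical but deliberate: callers that fan one shared keyword dict out to heterogeneous submodules would otherwise raise a TypeError on any module that does not declare a given argument. A minimal sketch of the failure mode, with hypothetical argument names:

    def forward_strict(input_ids, cache_position=None):
        return input_ids

    def forward_tolerant(input_ids, cache_position=None, **kwargs):
        return input_ids  # unknown extras are accepted and ignored

    shared_kwargs = {"cache_position": None, "output_router_logits": False}
    forward_tolerant(1, **shared_kwargs)   # fine
    # forward_strict(1, **shared_kwargs)   # TypeError: unexpected keyword argument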
transformers/models/seamless_m4t/modeling_seamless_m4t.py

@@ -1770,6 +1770,7 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
  ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  output_hidden_states = (
@@ -1914,6 +1915,7 @@ class SeamlessM4TTextToUnitModel(SeamlessM4TPreTrainedModel):
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
  ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  output_hidden_states = (
@@ -2035,6 +2037,7 @@ class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
  ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -2354,7 +2357,7 @@ class SeamlessM4TCodeHifiGan(PreTrainedModel):
  return input_lengths

  def forward(
- self, input_ids: torch.LongTensor, spkr_id: torch.Tensor, lang_id: torch.Tensor
+ self, input_ids: torch.LongTensor, spkr_id: torch.Tensor, lang_id: torch.Tensor, **kwargs
  ) -> tuple[torch.Tensor]:
  """
  Args:
@@ -2996,6 +2999,7 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
  ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
transformers/models/seamless_m4t/tokenization_seamless_m4t.py

@@ -60,7 +60,7 @@ class SeamlessM4TTokenizer(TokenizersBackend):
  Args:
  vocab (`list` or `dict`, *optional*):
  List of (token, score) tuples or dict mapping tokens to indices. If not provided, uses default vocab.
- merges (`list`, *optional*):
+ merges (`str` or `list`, *optional*):
  List of merge rules for BPE model. If not provided, uses empty list.
  bos_token (`str`, *optional*, defaults to `"<s>"`):
  The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
@@ -104,15 +104,15 @@ class SeamlessM4TTokenizer(TokenizersBackend):

  vocab_files_names = VOCAB_FILES_NAMES
  model_input_names = ["input_ids", "attention_mask"]
- slow_tokenizer_class = None
+ model = BPE

- prefix_tokens: list[int] = []
- suffix_tokens: list[int] = []
+ prefix_tokens: list[int] = None
+ suffix_tokens: list[int] = None

  def __init__(
  self,
- vocab: Optional[list] = None,
- merges: Optional[list] = None,
+ vocab: Optional[Union[str, dict[str, int]]] = None,
+ merges: Optional[Union[str, list[str]]] = None,
  bos_token="<s>",
  eos_token="</s>",
  sep_token="</s>",
@@ -126,59 +126,14 @@ class SeamlessM4TTokenizer(TokenizersBackend):
         vocab_file=None,
         **kwargs,
     ):
-        if vocab is None:
-            vocab = {
-                str(pad_token): 0,
-                str(unk_token): 1,
-                str(bos_token): 2,
-                str(eos_token): 3,
-            }
-
-        # Process vocab - SeamlessM4T uses fairseq vocab alignment: <pad>=0, <unk>=1, <s>=2, </s>=3, then SPM pieces[3:]
-        if isinstance(vocab, list):
-            # Convert list of (token, score) tuples to dict {token: idx}
-            # Check if vocab is already in SeamlessM4T order (pad, unk, s, /s) or tokenizer.json order (unk, s, /s, ...)
-            first_tokens = [str(item[0]) if isinstance(item, (list, tuple)) else str(item) for item in vocab[:4]]
-            is_seamless_order = (
-                len(first_tokens) >= 4
-                and first_tokens[0] == str(pad_token)
-                and first_tokens[1] == str(unk_token)
-                and first_tokens[2] == str(bos_token)
-                and first_tokens[3] == str(eos_token)
-            )
-
-            if is_seamless_order:
-                # Already in correct order, use list index directly as token ID
-                vocab_dict = {}
-                for idx, item in enumerate(vocab):
-                    token = str(item[0]) if isinstance(item, (list, tuple)) else str(item)
-                    vocab_dict[token] = idx
-                self._vocab = vocab_dict
-            else:
-                # Reorder to fairseq: <pad>, <unk>, <s>, </s>, ... (rest of vocab)
-                vocab_dict = {}
-                vocab_dict[str(pad_token)] = 0
-                vocab_dict[str(unk_token)] = 1
-                vocab_dict[str(bos_token)] = 2
-                vocab_dict[str(eos_token)] = 3
-
-                # Add rest of vocab starting from index 4, skipping tokens we already added
-                idx = 4
-                for item in vocab:
-                    token = str(item[0]) if isinstance(item, (list, tuple)) else str(item)
-                    if token not in vocab_dict:
-                        vocab_dict[token] = idx
-                        idx += 1
-
-                self._vocab = vocab_dict
-        else:
-            self._vocab = vocab
-
-        if merges is None:
-            self._merges = []
-        else:
-            self._merges = [tuple(merge) if isinstance(merge, list) else merge for merge in merges]
+        self._vocab = vocab or {
+            str(pad_token): 0,
+            str(unk_token): 1,
+            str(bos_token): 2,
+            str(eos_token): 3,
+        }
 
+        self._merges = merges or []
         self._tokenizer = Tokenizer(
             BPE(
                 vocab=self._vocab,
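The rewrite above drops the fairseq reordering logic in favor of plain fallbacks. Note that `vocab or {...}` and `merges or []` treat an empty dict or list the same as `None`, which is presumably intended here since an empty vocab is unusable anyway. A small sketch of the fallback semantics, using toy tokens only:

```python
pad_token, unk_token, bos_token, eos_token = "<pad>", "<unk>", "<s>", "</s>"


def resolve_vocab(vocab=None):
    # `vocab or {...}` falls back when vocab is None *or* empty/falsy.
    return vocab or {
        str(pad_token): 0,
        str(unk_token): 1,
        str(bos_token): 2,
        str(eos_token): 3,
    }


assert resolve_vocab() == {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3}
assert resolve_vocab({"hello": 7}) == {"hello": 7}  # caller-supplied vocab wins
assert resolve_vocab({}) == resolve_vocab()         # empty dict also falls back
```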
@@ -216,7 +171,6 @@ class SeamlessM4TTokenizer(TokenizersBackend):
         kwargs.setdefault("additional_special_tokens", additional_special_tokens)
 
         super().__init__(
-            tokenizer_object=self._tokenizer,
             bos_token=bos_token,
             eos_token=eos_token,
             sep_token=sep_token,
@@ -245,6 +199,20 @@ class SeamlessM4TTokenizer(TokenizersBackend):
 
         self.set_tgt_lang_special_tokens(self._tgt_lang)
 
+    @classmethod
+    def convert_from_spm_model(cls, vocab, **kwargs):
+        """When converting from spm, offset is needed to account for special tokens."""
+        _vocab = {
+            "<pad>": 0,
+            "<unk>": 1,
+            "<s>": 2,
+            "</s>": 3,
+        }
+        for i, token in enumerate(list(vocab.keys())):
+            _vocab[token] = i + 1  # offset by 1 to account for special tokens
+        kwargs["vocab"] = _vocab
+        return kwargs
+
     @property
     def src_lang(self) -> str:
         return self._src_lang
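The offset-by-1 in the new `convert_from_spm_model` helper makes sense once you recall that sentencepiece models conventionally place `<unk>` at id 0, `<s>` at 1, and `</s>` at 2: shifting every id by one leaves slot 0 free for `<pad>`, producing the fairseq alignment `<pad>=0, <unk>=1, <s>=2, </s>=3`. A usage sketch; the token strings and ids below are illustrative, not from a real checkpoint:

```python
from transformers import SeamlessM4TTokenizer

# Toy SPM-style vocab: <unk>=0, <s>=1, </s>=2, then regular pieces.
spm_vocab = {"<unk>": 0, "<s>": 1, "</s>": 2, "▁hello": 3, "▁world": 4}

kwargs = SeamlessM4TTokenizer.convert_from_spm_model(spm_vocab)
# Every id shifts up by one, and <pad> takes the freed slot 0:
assert kwargs["vocab"] == {
    "<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3, "▁hello": 4, "▁world": 5,
}
```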
@@ -1812,6 +1812,7 @@ class SeamlessM4Tv2Decoder(SeamlessM4Tv2PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1995,6 +1996,7 @@ class SeamlessM4Tv2TextToUnitDecoder(SeamlessM4Tv2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SeamlessM4Tv2TextToUnitDecoderOutput]:
         r"""
         Args:
@@ -2122,6 +2124,7 @@ class SeamlessM4Tv2TextToUnitModel(SeamlessM4Tv2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -2556,7 +2559,7 @@ class SeamlessM4Tv2CodeHifiGan(PreTrainedModel):
 
     # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.forward with SeamlessM4T->SeamlessM4Tv2, spkr_id->speaker_id
     def forward(
-        self, input_ids: torch.LongTensor, speaker_id: torch.Tensor, lang_id: torch.Tensor
+        self, input_ids: torch.LongTensor, speaker_id: torch.Tensor, lang_id: torch.Tensor, **kwargs
     ) -> tuple[torch.Tensor]:
         """
         Args:
@@ -3214,6 +3217,7 @@ class SeamlessM4Tv2ForTextToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -40,7 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_seed_oss import SeedOssConfig
 
 
@@ -350,7 +350,7 @@ class SeedOssRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
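`maybe_autocast` replaces `torch.autocast` here, but its body is not part of this diff; judging by the name and the import from `...utils.generic`, it is presumably a thin wrapper that degrades to a no-op on device types where autocast is unavailable. A hypothetical sketch of such a wrapper, an assumption rather than the actual transformers implementation:

```python
from contextlib import nullcontext

import torch


def maybe_autocast_sketch(device_type: str, **kwargs):
    # Hypothetical: return a real autocast context when the backend supports
    # it, and a no-op context manager otherwise.
    try:
        return torch.autocast(device_type=device_type, **kwargs)
    except RuntimeError:
        return nullcontext()


# Usage mirrors the rotary-embedding call site above:
with maybe_autocast_sketch(device_type="cpu", enabled=False):  # force float32
    freqs = torch.randn(1, 2, 4).float() @ torch.randn(1, 4, 8).float()
```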
@@ -434,6 +434,7 @@ class SegformerModel(SegformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -486,6 +487,7 @@ class SegformerForImageClassification(SegformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SegFormerImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -572,7 +574,7 @@ class SegformerDecodeHead(SegformerPreTrainedModel):
 
         self.config = config
 
-    def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
+    def forward(self, encoder_hidden_states: torch.FloatTensor, **kwargs) -> torch.Tensor:
         batch_size = encoder_hidden_states[-1].shape[0]
 
         all_hidden_states = ()
@@ -627,6 +629,7 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SemanticSegmenterOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
@@ -647,6 +647,7 @@ class SegGptModel(SegGptPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SegGptEncoderOutput]:
         r"""
         prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
@@ -843,6 +844,7 @@ class SegGptForImageSegmentation(SegGptPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SegGptImageSegmentationOutput]:
         r"""
         prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):