transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/sam3/modeling_sam3.py
@@ -14,9 +14,8 @@
  # limitations under the License.


- import collections.abc
  import math
- from collections.abc import Callable
+ from collections.abc import Callable, Iterable
  from dataclasses import dataclass
  from typing import Optional, Union

@@ -40,7 +39,7 @@ from ...modeling_outputs import (
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
  from ...pytorch_utils import compile_compatible_method_lru_cache
- from ...utils import auto_docstring
+ from ...utils import auto_docstring, logging
  from ...utils.generic import TransformersKwargs, check_model_inputs
  from ..auto import AutoModel
  from .configuration_sam3 import (
@@ -54,6 +53,9 @@ from .configuration_sam3 import (
  )


+ logger = logging.get_logger(__name__)
+
+
  @dataclass
  @auto_docstring
  class Sam3VisionEncoderOutput(ModelOutput):
@@ -123,8 +125,8 @@ class Sam3DETRDecoderOutput(ModelOutput):
              Decoder hidden states from all layers.
          reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`):
              Predicted reference boxes from all decoder layers in (cx, cy, w, h) format.
-         presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size)`, *optional*):
-             Presence logits from all decoder layers (None if using instance queries).
+         presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size, 1)`):
+             Presence logits from all decoder layers indicating object presence confidence.
          hidden_states (`tuple[torch.FloatTensor]`, *optional*):
              Tuple of hidden states from all decoder layers.
          attentions (`tuple[torch.FloatTensor]`, *optional*):
@@ -133,7 +135,7 @@ class Sam3DETRDecoderOutput(ModelOutput):

      intermediate_hidden_states: torch.FloatTensor = None
      reference_boxes: torch.FloatTensor = None
-     presence_logits: Optional[torch.FloatTensor] = None
+     presence_logits: torch.FloatTensor = None
      hidden_states: Optional[tuple[torch.FloatTensor]] = None
      attentions: Optional[tuple[torch.FloatTensor]] = None

@@ -372,6 +374,19 @@ class Sam3Attention(nn.Module):
372
374
  if self.config._attn_implementation != "eager":
373
375
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
374
376
 
377
+ if (
378
+ "flash" in self.config._attn_implementation
379
+ and attention_mask is not None
380
+ and attention_mask.dtype != torch.bool
381
+ ):
382
+ # Relative position bias tensors are represented as float masks and are incompatible with Flash Attention
383
+ # Fallback to SDPA for this call only so the rest of the model can still benefit from FA
384
+ attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
385
+ logger.warning_once(
386
+ "Sam3Attention: falling back to SDPA for relative-position cross-attention because "
387
+ "Flash Attention does not support additive bias masks."
388
+ )
389
+
375
390
  attn_output, attn_weights = attention_interface(
376
391
  self,
377
392
  query,
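The new branch above is a per-call fallback: Flash Attention kernels accept only boolean (or absent) attention masks, so when the mask is an additive float bias, as with SAM 3's relative position bias, this single call is rerouted to SDPA while the rest of the model keeps its configured backend. A minimal standalone sketch of the same dispatch rule; `attention_functions` stands in for the `ALL_ATTENTION_FUNCTIONS` registry and the function name is hypothetical:

```python
import torch

def pick_attention_interface(impl: str, attention_mask, attention_functions):
    """Route additive (float) bias masks away from Flash Attention to SDPA."""
    interface = attention_functions[impl]
    if "flash" in impl and attention_mask is not None and attention_mask.dtype != torch.bool:
        # Flash Attention kernels accept only boolean (or absent) masks;
        # an additive float bias must go through SDPA instead.
        interface = attention_functions["sdpa"]
    return interface
```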
@@ -402,6 +417,10 @@ class Sam3ViTRotaryEmbedding(nn.Module):
  # Ensure even dimension for proper axial splitting
  if dim % 4 != 0:
  raise ValueError("Dimension must be divisible by 4 for axial RoPE")
+ self.end_x, self.end_y = end_x, end_y
+ self.dim = dim
+ self.rope_theta = config.rope_theta
+ self.scale = scale
  freqs = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))

  flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
@@ -531,8 +550,8 @@ class Sam3ViTPatchEmbeddings(nn.Module):
  image_size, patch_size = config.pretrain_image_size, config.patch_size
  num_channels, hidden_size = config.num_channels, config.hidden_size

- image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
- patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ image_size = image_size if isinstance(image_size, Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, Iterable) else (patch_size, patch_size)
  num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  self.image_size = image_size
  self.patch_size = patch_size
@@ -542,7 +561,7 @@ class Sam3ViTPatchEmbeddings(nn.Module):
  self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=False)

  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
- embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ embeddings = self.projection(pixel_values.to(self.projection.weight.dtype)).flatten(2).transpose(1, 2)
  return embeddings

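The `.to(self.projection.weight.dtype)` cast above guards against a dtype mismatch when the model runs in reduced precision but callers pass float32 pixel values. A minimal reproduction of the failure mode and the fix (bfloat16 is chosen here for CPU-friendliness; the same applies to float16 on GPU):

```python
import torch
from torch import nn

proj = nn.Conv2d(3, 8, kernel_size=2, stride=2).to(torch.bfloat16)
pixel_values = torch.randn(1, 3, 4, 4)  # float32, as image processors typically return
# proj(pixel_values) would raise a RuntimeError (input/weight dtype mismatch);
# casting the input to the weight dtype first, as the patch does, avoids it:
embeddings = proj(pixel_values.to(proj.weight.dtype)).flatten(2).transpose(1, 2)
print(embeddings.shape, embeddings.dtype)  # torch.Size([1, 4, 8]) torch.bfloat16
```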
@@ -761,6 +780,19 @@ class Sam3PreTrainedModel(PreTrainedModel):
  super()._init_weights(module)
  if isinstance(module, Sam3ViTEmbeddings):
  init.normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, Sam3ViTRotaryEmbedding):
+ end_x, end_y = module.end_x, module.end_y
+ dim = module.dim
+ freqs = 1.0 / (module.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+ flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
+ x_positions = (flattened_indices % end_x) * module.scale
+ y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * module.scale
+ freqs_x = torch.outer(x_positions, freqs).float()
+ freqs_y = torch.outer(y_positions, freqs).float()
+ inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+ inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+ init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+ init.copy_(module.rope_embeddings_sin, inv_freq.sin())


  @auto_docstring
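The new `Sam3ViTRotaryEmbedding` branch in `_init_weights` exists because `rope_embeddings_cos` and `rope_embeddings_sin` are registered with `persistent=False`: they never reach the checkpoint, so when weights are materialized (for example from the meta device) the tables must be recomputed from the stored hyperparameters. A toy sketch of that pattern, with hypothetical names:

```python
import torch
from torch import nn

class ToyRope(nn.Module):
    """Toy module with non-persistent RoPE tables (hypothetical, for illustration)."""

    def __init__(self, seq_len: int = 8, dim: int = 4):
        super().__init__()
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        table = torch.outer(torch.arange(seq_len).float(), inv_freq)
        # persistent=False -> excluded from state_dict, hence absent from checkpoints
        self.register_buffer("cos_table", table.cos(), persistent=False)
        self.register_buffer("sin_table", table.sin(), persistent=False)

module = ToyRope()
print("cos_table" in module.state_dict())  # False: the table must be rebuilt at init time
```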
@@ -938,6 +970,7 @@ class Sam3FPNLayer(nn.Module):
  self.proj2 = nn.Conv2d(in_channels=fpn_dim, out_channels=fpn_dim, kernel_size=3, padding=1)

  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = hidden_states.to(self.proj1.weight.dtype)
  for layer in self.scale_layers:
  hidden_states = layer(hidden_states)

@@ -1253,7 +1286,7 @@ class Sam3DetrEncoderLayer(nn.Module):
  vision_feats: Tensor,
  prompt_feats: Tensor,
  vision_pos_encoding: Tensor,
- prompt_mask: Tensor,
+ prompt_cross_attn_mask: Optional[Tensor] = None,
  **kwargs: Unpack[TransformersKwargs],
  ):
  """
@@ -1263,7 +1296,7 @@ class Sam3DetrEncoderLayer(nn.Module):
  vision_feats: Vision features [batch_size, vision_len, hidden_size] (main hidden states)
  prompt_feats: Text prompt features [batch_size, text_len, hidden_size]
  vision_pos_encoding: Position encoding for vision [batch_size, vision_len, hidden_size]
- prompt_mask: Padding mask for prompts [batch_size, text_len] where True=valid, False=padding
+ prompt_cross_attn_mask: Cross-attention mask for prompt features

  Returns:
  Updated vision features [batch_size, vision_len, hidden_size]
@@ -1284,15 +1317,6 @@
  residual = hidden_states
  hidden_states = self.layer_norm2(hidden_states)

- prompt_cross_attn_mask = None
- if prompt_mask is not None:
- prompt_cross_attn_mask = create_bidirectional_mask(
- config=self.config,
- input_embeds=hidden_states,
- attention_mask=prompt_mask,
- encoder_hidden_states=prompt_feats,
- )
-
  hidden_states, _ = self.cross_attn(
  query=hidden_states,
  key=prompt_feats,
@@ -1331,6 +1355,8 @@ class Sam3DetrEncoder(Sam3PreTrainedModel):

  self.layers = nn.ModuleList([Sam3DetrEncoderLayer(config) for _ in range(config.num_layers)])

+ self.post_init()
+
  def _prepare_multilevel_features(
  self,
  vision_features: list[torch.Tensor],
@@ -1412,13 +1438,22 @@
  spatial_shapes,
  ) = self._prepare_multilevel_features(vision_features, vision_pos_embeds)

+ prompt_cross_attn_mask = None
+ if text_mask is not None:
+ prompt_cross_attn_mask = create_bidirectional_mask(
+ config=self.config,
+ input_embeds=features_flattened,
+ attention_mask=text_mask,
+ encoder_hidden_states=text_features,
+ )
+
  hidden_states = features_flattened
  for layer in self.layers:
  hidden_states = layer(
  hidden_states,
  prompt_feats=text_features,
  vision_pos_encoding=pos_embeds_flattened,
- prompt_mask=text_mask,
+ prompt_cross_attn_mask=prompt_cross_attn_mask,
  **kwargs,
  )
  return Sam3DETREncoderOutput(
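The two hunks above move `create_bidirectional_mask` out of `Sam3DetrEncoderLayer.forward` and into the encoder's forward, so the prompt cross-attention mask is built once per forward pass instead of once per layer; the mask depends only on the prompt's padding pattern, not on layer activations, so hoisting it preserves behavior. For intuition only, a hedged sketch of what such a padding-to-additive-mask expansion typically does (not the actual `create_bidirectional_mask` implementation, which covers more cases):

```python
import torch

def expand_padding_mask(padding_mask: torch.Tensor, q_len: int, dtype=torch.float32):
    """Expand a [batch, kv_len] padding mask (True = valid) into an additive
    [batch, 1, q_len, kv_len] cross-attention mask (0 for valid, -inf for pad)."""
    batch, kv_len = padding_mask.shape
    mask = torch.zeros(batch, 1, q_len, kv_len, dtype=dtype)
    mask.masked_fill_(~padding_mask[:, None, None, :], torch.finfo(dtype).min)
    return mask

text_mask = torch.tensor([[True, True, False]])
print(expand_padding_mask(text_mask, q_len=2).shape)  # torch.Size([1, 1, 2, 3])
```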
@@ -1484,31 +1519,27 @@ class Sam3DetrDecoderLayer(nn.Module):
  text_features: torch.Tensor,
  vision_features: torch.Tensor,
  vision_pos_encoding: torch.Tensor,
- text_mask: Optional[torch.Tensor] = None,
+ text_cross_attn_mask: Optional[torch.Tensor] = None,
  vision_cross_attn_mask: Optional[torch.Tensor] = None,
- presence_token: Optional[torch.Tensor] = None,
  **kwargs: Unpack[TransformersKwargs],
- ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ ) -> torch.Tensor:
  """
  Forward pass for decoder layer.

  Args:
- hidden_states: Query features [batch_size, num_queries, hidden_size]
+ hidden_states: Query features [batch_size, num_queries + 1, hidden_size] (includes presence token at position 0)
  query_pos: Query position embeddings [batch_size, num_queries, hidden_size]
  text_features: Text features [batch_size, seq_len, hidden_size]
  vision_features: Vision features [batch_size, height*width, hidden_size]
  vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
- text_mask: Text padding mask [batch_size, seq_len] where True=valid, False=padding
- vision_cross_attn_mask: Vision cross-attention mask [batch_size, num_heads, num_queries, height*width]
- presence_token: Optional presence token [batch_size, 1, hidden_size]
+ text_cross_attn_mask: Text cross-attention mask
+ vision_cross_attn_mask: Vision cross-attention mask, already expanded for presence token

  Returns:
- Tuple of (updated hidden states, updated presence token)
+ Updated hidden states (including presence token at position 0)
  """
- # Concatenate presence token if provided
- if presence_token is not None:
- hidden_states = torch.cat([presence_token, hidden_states], dim=1)
- query_pos = torch.cat([torch.zeros_like(presence_token), query_pos], dim=1)
+ # Prepend zeros to query_pos for presence token
+ query_pos = F.pad(query_pos, (0, 0, 1, 0), mode="constant", value=0)

  # Self-attention with query position encoding
  residual = hidden_states
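With the presence token now permanently at position 0 of `hidden_states`, the layer prepends a zero position embedding instead of concatenating tensors: `F.pad(query_pos, (0, 0, 1, 0))` pads the second-to-last (sequence) dimension by one slot at the front, since the pad tuple is read in pairs from the last dimension backwards. A quick check:

```python
import torch
import torch.nn.functional as F

query_pos = torch.randn(2, 5, 8)         # [batch, num_queries, hidden]
padded = F.pad(query_pos, (0, 0, 1, 0))  # last dim: (0, 0); seq dim: 1 in front, 0 behind
print(padded.shape)                      # torch.Size([2, 6, 8])
print(padded[:, 0].abs().sum().item())   # 0.0 -> the presence token gets a zero embedding
```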
@@ -1527,15 +1558,6 @@
  residual = hidden_states
  query_with_pos = hidden_states + query_pos

- text_cross_attn_mask = None
- if text_mask is not None:
- text_cross_attn_mask = create_bidirectional_mask(
- config=self.config,
- input_embeds=hidden_states,
- attention_mask=text_mask,
- encoder_hidden_states=text_features,
- )
-
  attn_output, _ = self.text_cross_attn(
  query=query_with_pos,
  key=text_features,
@@ -1546,20 +1568,6 @@
  hidden_states = residual + self.text_cross_attn_dropout(attn_output)
  hidden_states = self.text_cross_attn_layer_norm(hidden_states)

- # Expand vision cross-attention mask for presence token if needed
- combined_vision_mask = vision_cross_attn_mask
- if presence_token is not None and combined_vision_mask is not None:
- batch_size, num_heads = combined_vision_mask.shape[:2]
- presence_mask = torch.zeros(
- batch_size,
- num_heads,
- 1,
- combined_vision_mask.shape[-1],
- device=combined_vision_mask.device,
- dtype=combined_vision_mask.dtype,
- )
- combined_vision_mask = torch.cat([presence_mask, combined_vision_mask], dim=2)
-
  # Vision cross-attention: queries attend to vision features (with RPB)
  residual = hidden_states
  query_with_pos = hidden_states + query_pos
@@ -1568,7 +1576,7 @@
  query=query_with_pos,
  key=key_with_pos,
  value=vision_features,
- attention_mask=combined_vision_mask,
+ attention_mask=vision_cross_attn_mask,
  **kwargs,
  )
  hidden_states = residual + self.vision_cross_attn_dropout(attn_output)
@@ -1580,13 +1588,7 @@
  hidden_states = residual + self.mlp_dropout(hidden_states)
  hidden_states = self.mlp_layer_norm(hidden_states)

- # Extract presence token if it was added
- presence_token_out = None
- if presence_token is not None:
- presence_token_out = hidden_states[:, :1]
- hidden_states = hidden_states[:, 1:]
-
- return hidden_states, presence_token_out
+ return hidden_states


  class Sam3DetrDecoder(Sam3PreTrainedModel):
@@ -1634,6 +1636,8 @@ class Sam3DetrDecoder(Sam3PreTrainedModel):

  self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.hidden_size // 2, normalize=False)

+ self.post_init()
+
  @compile_compatible_method_lru_cache(maxsize=1)
  def _get_coords(
  self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype, device: torch.device
@@ -1715,11 +1719,23 @@
  """
  batch_size = vision_features.shape[0]

- hidden_states = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
+ query_embeds = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
  reference_boxes = self.reference_points.weight.unsqueeze(0).expand(batch_size, -1, -1)
  reference_boxes = reference_boxes.sigmoid()
  presence_token = self.presence_token.weight.unsqueeze(0).expand(batch_size, -1, -1)

+ # Concatenate presence token with query embeddings
+ hidden_states = torch.cat([presence_token, query_embeds], dim=1)
+
+ text_cross_attn_mask = None
+ if text_mask is not None:
+ text_cross_attn_mask = create_bidirectional_mask(
+ config=self.config,
+ input_embeds=hidden_states,
+ attention_mask=text_mask,
+ encoder_hidden_states=text_features,
+ )
+
  intermediate_outputs = []
  intermediate_boxes = [reference_boxes]
  intermediate_presence_logits = []
@@ -1734,43 +1750,45 @@
  vision_cross_attn_mask = None
  if spatial_shapes is not None and spatial_shapes.shape[0] == 1:
  spatial_shape = (spatial_shapes[0, 0], spatial_shapes[0, 1])
- vision_cross_attn_mask = self._get_rpb_matrix(reference_boxes, spatial_shape)
+ rpb_matrix = self._get_rpb_matrix(reference_boxes, spatial_shape)
+ # Prepend zeros row for presence token (it attends to all vision tokens equally)
+ vision_cross_attn_mask = F.pad(rpb_matrix, (0, 0, 1, 0), mode="constant", value=0)

- hidden_states, presence_token = layer(
+ hidden_states = layer(
  hidden_states,
  query_pos=query_pos,
  text_features=text_features,
  vision_features=vision_features,
  vision_pos_encoding=vision_pos_encoding,
- text_mask=text_mask,
+ text_cross_attn_mask=text_cross_attn_mask,
  vision_cross_attn_mask=vision_cross_attn_mask,
- presence_token=presence_token,
  **kwargs,
  )

+ # Extract query hidden states (without presence token) for box refinement
+ query_hidden_states = hidden_states[:, 1:]
+
  # Box refinement: predict delta and update reference boxes
  reference_boxes_before_sigmoid = inverse_sigmoid(reference_boxes)
- delta_boxes = self.box_head(self.output_layer_norm(hidden_states))
+ delta_boxes = self.box_head(self.output_layer_norm(query_hidden_states))
  new_reference_boxes = (delta_boxes + reference_boxes_before_sigmoid).sigmoid()
  reference_boxes = new_reference_boxes.detach()

- intermediate_outputs.append(self.output_layer_norm(hidden_states))
+ intermediate_outputs.append(self.output_layer_norm(query_hidden_states))
  intermediate_boxes.append(new_reference_boxes)

  # Process presence token
- if presence_token is not None:
- presence_logits = self.presence_head(self.presence_layer_norm(presence_token)).squeeze(-1)
- presence_logits = presence_logits.clamp(
- min=-self.clamp_presence_logit_max_val, max=self.clamp_presence_logit_max_val
- )
- intermediate_presence_logits.append(presence_logits)
+ presence_hidden = hidden_states[:, :1]
+ presence_logits = self.presence_head(self.presence_layer_norm(presence_hidden)).squeeze(-1)
+ presence_logits = presence_logits.clamp(
+ min=-self.clamp_presence_logit_max_val, max=self.clamp_presence_logit_max_val
+ )
+ intermediate_presence_logits.append(presence_logits)

  # Stack outputs from all layers
  intermediate_outputs = torch.stack(intermediate_outputs)
  intermediate_boxes = torch.stack(intermediate_boxes[:-1])
- intermediate_presence_logits = (
- torch.stack(intermediate_presence_logits) if intermediate_presence_logits else None
- )
+ intermediate_presence_logits = torch.stack(intermediate_presence_logits)

  return Sam3DETRDecoderOutput(
  intermediate_hidden_states=intermediate_outputs,
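Mirroring the decoder-layer change, the loop above pads the RPB-based vision cross-attention mask with one zero query row so the presence token attends to all vision tokens without bias, and presence logits are now stacked unconditionally, matching the `(num_layers, batch_size, 1)` shape documented earlier. A shape walk-through with illustrative sizes:

```python
import torch
import torch.nn.functional as F

batch, heads, num_queries, hw = 2, 8, 100, 64 * 64
rpb_matrix = torch.randn(batch, heads, num_queries, hw)
# One extra query row, all zeros, for the presence token at query index 0:
vision_cross_attn_mask = F.pad(rpb_matrix, (0, 0, 1, 0), mode="constant", value=0)
print(vision_cross_attn_mask.shape)                        # torch.Size([2, 8, 101, 4096])
print(vision_cross_attn_mask[:, :, 0].abs().sum().item())  # 0.0 -> unbiased presence row
```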
@@ -1990,6 +2008,8 @@ class Sam3MaskDecoder(Sam3PreTrainedModel):
  self.prompt_cross_attn_norm = nn.LayerNorm(hidden_size)
  self.prompt_cross_attn_dropout = nn.Dropout(config.dropout)

+ self.post_init()
+
  @check_model_inputs
  def forward(
  self,
@@ -107,7 +107,12 @@ class Sam3TrackerFeedForward(nn.Module):
  return hidden_states


- @auto_docstring
+ @auto_docstring(
+ custom_intro="""
+ Segment Anything Model 3 (SAM 3) for generating segmentation masks, given an input image and
+ input points and labels, boxes, or masks.
+ """
+ )
  class Sam3TrackerPreTrainedModel(PreTrainedModel):
  config_class = Sam3TrackerConfig
  base_model_prefix = "sam3_tracker"
@@ -123,6 +128,8 @@ class Sam3TrackerPreTrainedModel(PreTrainedModel):
  if isinstance(module, Sam3TrackerModel):
  if module.no_memory_embedding is not None:
  init.zeros_(module.no_memory_embedding)
+ elif isinstance(module, Sam3TrackerPositionalEmbedding):
+ init.normal_(module.positional_embedding, std=module.scale)


  class Sam3TrackerPositionalEmbedding(nn.Module):
@@ -136,7 +136,12 @@ class Sam3TrackerFeedForward(Sam2FeedForward):
  pass


- @auto_docstring
+ @auto_docstring(
+ custom_intro="""
+ Segment Anything Model 3 (SAM 3) for generating segmentation masks, given an input image and
+ input points and labels, boxes, or masks.
+ """
+ )
  class Sam3TrackerPreTrainedModel(Sam2PreTrainedModel):
  @torch.no_grad()
  def _init_weights(self, module):
@@ -144,6 +149,8 @@ class Sam3TrackerPreTrainedModel(Sam2PreTrainedModel):
  if isinstance(module, Sam3TrackerModel):
  if module.no_memory_embedding is not None:
  init.zeros_(module.no_memory_embedding)
+ elif isinstance(module, Sam3TrackerPositionalEmbedding):
+ init.normal_(module.positional_embedding, std=module.scale)


  class Sam3TrackerPositionalEmbedding(Sam2PositionalEmbedding):
@@ -397,5 +397,30 @@ class Sam3TrackerVideoConfig(PreTrainedConfig):

  super().__init__(**kwargs)

+ @property
+ def image_size(self):
+ """Image size for the tracker video model."""
+ return self.vision_config.image_size
+
+ @image_size.setter
+ def image_size(self, value):
+ """Set the image size and propagate to sub-configs. Calculates feature sizes based on patch_size."""
+ self.prompt_encoder_config.image_size = value
+ self.vision_config.image_size = value
+
+ patch_size = self.vision_config.backbone_config.patch_size
+ self.vision_config.backbone_feature_sizes = [
+ [4 * value // patch_size, 4 * value // patch_size],
+ [2 * value // patch_size, 2 * value // patch_size],
+ [value // patch_size, value // patch_size],
+ ]
+ self.memory_attention_rope_feat_sizes = [
+ value // patch_size,
+ value // patch_size,
+ ]
+
+ # keep the image_size in the __dict__ to save the value in the config file (backward compatibility)
+ self.__dict__["image_size"] = value
+

  __all__ = ["Sam3TrackerVideoMaskDecoderConfig", "Sam3TrackerVideoPromptEncoderConfig", "Sam3TrackerVideoConfig"]
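The new `image_size` property keeps `prompt_encoder_config`, `vision_config`, the derived backbone feature sizes, and the memory-attention RoPE feature sizes consistent from a single assignment. A worked example of the setter arithmetic, assuming a backbone patch size of 14 (an illustrative assumption, not a value taken from this diff):

```python
# Worked example of the setter arithmetic above, assuming patch_size = 14.
value, patch_size = 1008, 14
backbone_feature_sizes = [
    [4 * value // patch_size, 4 * value // patch_size],  # [288, 288]
    [2 * value // patch_size, 2 * value // patch_size],  # [144, 144]
    [value // patch_size, value // patch_size],          # [72, 72]
]
memory_attention_rope_feat_sizes = [value // patch_size, value // patch_size]  # [72, 72]
```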
@@ -213,7 +213,7 @@ class Sam3TrackerVideoInferenceSession:
  device_inputs = {}
  for key, value in inputs.items():
  if isinstance(value, torch.Tensor):
- device_inputs[key] = value.to(self.inference_device, non_blocking=True)
+ device_inputs[key] = value.to(self.inference_device, non_blocking=False)
  else:
  device_inputs[key] = value
  self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs
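Switching `non_blocking=True` to `non_blocking=False` trades potential copy/compute overlap for a guarantee: the tensor is fully resident on `inference_device` when `.to()` returns, so later reads cannot race with an in-flight asynchronous copy. A minimal illustration of the synchronous contract:

```python
import torch

inputs = torch.randn(1024)
if torch.cuda.is_available():
    # non_blocking=False: .to() returns only after the copy has completed, so the
    # result is immediately safe to read from any stream and the source buffer
    # can be reused; non_blocking=True (from pinned memory) merely enqueues the copy.
    device_inputs = inputs.to("cuda", non_blocking=False)
    print(device_inputs.device)  # cuda:0
```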
@@ -692,6 +692,12 @@ class Sam3TrackerVideoPreTrainedModel(PreTrainedModel):
  if isinstance(module, Sam3TrackerVideoMemoryFuserCXBlock):
  if module.scale is not None:
  init.zeros_(module.scale)
+ elif isinstance(module, Sam3TrackerVideoVisionRotaryEmbedding):
+ inv_freq = module.create_inv_freq()
+ init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+ init.copy_(module.rope_embeddings_sin, inv_freq.sin())
+ elif isinstance(module, Sam3TrackerVideoPositionalEmbedding):
+ init.normal_(module.positional_embedding, std=module.scale)


  class Sam3TrackerVideoVisionRotaryEmbedding(nn.Module):
@@ -702,24 +708,17 @@ class Sam3TrackerVideoVisionRotaryEmbedding(nn.Module):

  def __init__(self, config: Sam3TrackerVideoConfig):
  super().__init__()
- dim = config.memory_attention_hidden_size // (
+ self.dim = config.memory_attention_hidden_size // (
  config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
  )
  # Ensure even dimension for proper axial splitting
- if dim % 4 != 0:
+ if self.dim % 4 != 0:
  raise ValueError("Dimension must be divisible by 4 for axial RoPE")
- end_x, end_y = config.memory_attention_rope_feat_sizes
- freqs = 1.0 / (config.memory_attention_rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+ self.end_x, self.end_y = config.memory_attention_rope_feat_sizes
+ self.memory_attention_rope_theta = config.memory_attention_rope_theta

- # Generate 2D position indices for axial rotary embedding
- flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
- x_positions = flattened_indices % end_x
- y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
- freqs_x = torch.outer(x_positions, freqs).float()
- freqs_y = torch.outer(y_positions, freqs).float()
- inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
- inv_freq = inv_freq.repeat_interleave(2, dim=-1)
  # directly register the cos and sin embeddings as we have a fixed feature shape
+ inv_freq = self.create_inv_freq()
  self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
  self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)

@@ -728,6 +727,20 @@ class Sam3TrackerVideoVisionRotaryEmbedding(nn.Module):
  # As the feature map size is fixed, we can just return the pre-computed embeddings.
  return self.rope_embeddings_cos, self.rope_embeddings_sin

+ def create_inv_freq(self):
+ freqs = 1.0 / (
+ self.memory_attention_rope_theta ** (torch.arange(0, self.dim, 4)[: (self.dim // 4)].float() / self.dim)
+ )
+ # Generate 2D position indices for axial rotary embedding
+ flattened_indices = torch.arange(self.end_x * self.end_y, dtype=torch.long)
+ x_positions = flattened_indices % self.end_x
+ y_positions = torch.div(flattened_indices, self.end_x, rounding_mode="floor")
+ freqs_x = torch.outer(x_positions, freqs).float()
+ freqs_y = torch.outer(y_positions, freqs).float()
+ inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+ inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+ return inv_freq
+

  def rotate_pairwise(x):
  """
@@ -1567,8 +1580,6 @@ class Sam3TrackerVideoModel(Sam3TrackerVideoPreTrainedModel):
  input_modalities = ("video", "text")
  _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam3TrackerVideoTwoWayAttentionBlock, index=2)}
  _keys_to_ignore_on_load_unexpected = [r"^detector_model."]
- _tied_weights_keys = {}
- _keys_to_ignore_on_load_missing = []
  _checkpoint_conversion_mapping = {
  r"tracker_model.(.+)": r"\1", # the regex allows to remove the prefix, and add it back in revert mode
  "detector_model.vision_encoder.backbone.": "vision_encoder.backbone.",
@@ -1719,6 +1730,7 @@
  frame: Optional[torch.Tensor] = None,
  reverse: bool = False,
  run_mem_encoder: bool = True,
+ **kwargs,
  ) -> Sam3TrackerVideoSegmentationOutput:
  r"""
  inference_session (`Sam3TrackerVideoInferenceSession`):
@@ -353,6 +353,31 @@ class Sam3TrackerVideoConfig(PreTrainedConfig):

  super().__init__(**kwargs)

+ @property
+ def image_size(self):
+ """Image size for the tracker video model."""
+ return self.vision_config.image_size
+
+ @image_size.setter
+ def image_size(self, value):
+ """Set the image size and propagate to sub-configs. Calculates feature sizes based on patch_size."""
+ self.prompt_encoder_config.image_size = value
+ self.vision_config.image_size = value
+
+ patch_size = self.vision_config.backbone_config.patch_size
+ self.vision_config.backbone_feature_sizes = [
+ [4 * value // patch_size, 4 * value // patch_size],
+ [2 * value // patch_size, 2 * value // patch_size],
+ [value // patch_size, value // patch_size],
+ ]
+ self.memory_attention_rope_feat_sizes = [
+ value // patch_size,
+ value // patch_size,
+ ]
+
+ # keep the image_size in the __dict__ to save the value in the config file (backward compatibility)
+ self.__dict__["image_size"] = value
+

  class Sam3TrackerVideoInferenceCache(Sam2VideoInferenceCache):
  pass
@@ -461,8 +486,6 @@ class Sam3TrackerVideoModel(Sam2VideoModel):
  "tracker_neck.": "vision_encoder.neck.",
  }
  _keys_to_ignore_on_load_unexpected = [r"^detector_model."]
- _tied_weights_keys = {}
- _keys_to_ignore_on_load_missing = []


  def __init__(self, config: Sam3TrackerVideoConfig, remove_vision_encoder: bool = False):
@@ -96,6 +96,9 @@ class Sam3VideoConfig(PreTrainedConfig):
  >>> # Initializing a SAM3 Video configuration with default detector and tracker
  >>> configuration = Sam3VideoConfig()

+ >>> # Changing image size for custom resolution inference (automatically propagates to all nested configs)
+ >>> configuration.image_size = 560
+
  >>> # Initializing a model from the configuration
  >>> model = Sam3VideoModel(configuration)

@@ -225,5 +228,16 @@ class Sam3VideoConfig(PreTrainedConfig):
  self.high_conf_thresh = high_conf_thresh
  self.high_iou_thresh = high_iou_thresh

+ @property
+ def image_size(self):
+ """Image size for the video model."""
+ return self.detector_config.image_size
+
+ @image_size.setter
+ def image_size(self, value):
+ """Recursively propagate the image size to detector and tracker configs."""
+ self.detector_config.image_size = value
+ self.tracker_config.image_size = value
+

  __all__ = ["Sam3VideoConfig"]
@@ -33,7 +33,7 @@ from .configuration_sam3_video import Sam3VideoConfig


  if is_kernels_available():
- from kernels import get_kernel
+ from ...integrations.hub_kernels import get_kernel

  logger = logging.get_logger(__name__)

@@ -505,8 +505,6 @@ class Sam3VideoPreTrainedModel(PreTrainedModel):

  @auto_docstring
  class Sam3VideoModel(Sam3VideoPreTrainedModel):
- all_tied_weights_keys = {}
-
  def __init__(self, config: Sam3VideoConfig):
  super().__init__(config)
  self.config = config
@@ -542,6 +540,8 @@ class Sam3VideoModel(Sam3VideoPreTrainedModel):

  self.tracker_neck = Sam3VisionNeck(config.detector_config.vision_config)

+ self.post_init()
+
  def get_vision_features_for_tracker(self, vision_embeds: torch.Tensor):
  hidden_states = vision_embeds.last_hidden_state
  batch_size = hidden_states.shape[0]
@@ -1697,6 +1697,7 @@ class Sam3VideoModel(Sam3VideoPreTrainedModel):
  frame_idx: Optional[int] = None,
  frame: Optional[torch.Tensor] = None,
  reverse: bool = False,
+ **kwargs,
  ):
  r"""
  inference_session (`Sam3VideoInferenceSession`):