transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/resnet/modeling_resnet.py
@@ -262,9 +262,14 @@ class ResNetPreTrainedModel(PreTrainedModel):
                 fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
                 bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                 init.uniform_(module.bias, -bound, bound)
-        elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
-            init.constant_(module.weight, 1)
-            init.constant_(module.bias, 0)
+        # We need to check it like that as some Detr models replace the BatchNorm2d by their own
+        elif "BatchNorm" in module.__class__.__name__:
+            init.ones_(module.weight)
+            init.zeros_(module.bias)
+            init.zeros_(module.running_mean)
+            init.ones_(module.running_var)
+            if getattr(module, "num_batches_tracked", None) is not None:
+                init.zeros_(module.num_batches_tracked)
 
 
 @auto_docstring
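Note on the hunk above: an `isinstance(module, (nn.BatchNorm2d, nn.GroupNorm))` check misses DETR-style batch-norm replacements that do not subclass `nn.BatchNorm2d`, which is why the new branch matches on the class name, as the diff's own comment says. A minimal sketch of the difference; `FrozenBatchNorm2d` here is a hypothetical stand-in for such a replacement, not a class from this diff:

import torch
from torch import nn

class FrozenBatchNorm2d(nn.Module):
    # Hypothetical stand-in for a DETR-style batch norm replacement that
    # carries the usual weight/bias/running stats but is a plain nn.Module.
    def __init__(self, num_features):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

module = FrozenBatchNorm2d(64)
print(isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)))  # False: old check skips it
print("BatchNorm" in module.__class__.__name__)            # True: new check resets it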
transformers/models/roberta/modeling_roberta.py
@@ -501,6 +501,9 @@ class RobertaPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, RobertaLMHead):
             init.zeros_(module.bias)
+        elif isinstance(module, RobertaEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
 
 
 class RobertaEncoder(nn.Module):

transformers/models/roberta/modular_roberta.py
@@ -172,6 +172,9 @@ class RobertaPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, RobertaLMHead):
             init.zeros_(module.bias)
+        elif isinstance(module, RobertaEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
 
 
 class RobertaModel(BertModel):

transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py
@@ -561,6 +561,9 @@ class RobertaPreLayerNormPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, RobertaPreLayerNormLMHead):
             init.zeros_(module.bias)
+        elif isinstance(module, RobertaPreLayerNormEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
 
 
 @auto_docstring(

transformers/models/roc_bert/modeling_roc_bert.py
@@ -621,6 +621,9 @@ class RoCBertPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, RoCBertLMPredictionHead):
             init.zeros_(module.bias)
+        elif isinstance(module, RoCBertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
 
 
 @auto_docstring(
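The four embedding branches above all re-materialize registered buffers rather than parameters. A rough sketch of what the two `init` calls amount to, assuming `position_ids` and `token_type_ids` are registered as buffers the way `RobertaEmbeddings` registers them (sizes illustrative):

import torch
from torch import nn

class Embeddings(nn.Module):
    # Buffer layout assumed to mirror RobertaEmbeddings.
    def __init__(self, max_position_embeddings=512):
        super().__init__()
        self.register_buffer("position_ids", torch.empty(1, max_position_embeddings, dtype=torch.long))
        self.register_buffer("token_type_ids", torch.empty(1, max_position_embeddings, dtype=torch.long))

emb = Embeddings()
with torch.no_grad():
    # init.copy_(module.position_ids, ...) boils down to an in-place copy:
    emb.position_ids.copy_(torch.arange(emb.position_ids.shape[-1]).expand((1, -1)))
    # init.zeros_(module.token_type_ids) zero-fills in place:
    emb.token_type_ids.zero_()
print(emb.position_ids[0, :5].tolist())  # [0, 1, 2, 3, 4]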
transformers/models/rt_detr/configuration_rt_detr.py
@@ -44,7 +44,7 @@ class RTDetrConfig(PreTrainedConfig):
             The epsilon used by the layer normalization layers.
         batch_norm_eps (`float`, *optional*, defaults to 1e-05):
             The epsilon used by the batch normalization layers.
-        backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `RTDetrResNetConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this

transformers/models/rt_detr/modeling_rt_detr.py
@@ -1059,6 +1059,10 @@ class RTDetrPreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
 
         elif isinstance(module, nn.LayerNorm):
             init.ones_(module.weight)

transformers/models/rt_detr/modeling_rt_detr_resnet.py
@@ -316,9 +316,14 @@ class RTDetrResNetPreTrainedModel(PreTrainedModel):
                 fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
                 bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                 init.uniform_(module.bias, -bound, bound)
-        elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
-            init.constant_(module.weight, 1)
-            init.constant_(module.bias, 0)
+        # We need to check it like that as some Detr models replace the BatchNorm2d by their own
+        elif "BatchNorm" in module.__class__.__name__:
+            init.ones_(module.weight)
+            init.zeros_(module.bias)
+            init.zeros_(module.running_mean)
+            init.ones_(module.running_var)
+            if getattr(module, "num_batches_tracked", None) is not None:
+                init.zeros_(module.num_batches_tracked)
 
 
 @auto_docstring(

transformers/models/rt_detr_v2/configuration_rt_detr_v2.py
@@ -18,7 +18,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from ...configuration_utils import PreTrainedConfig
 from ...utils import logging
 from ...utils.backbone_utils import verify_backbone_config_arguments
@@ -49,7 +48,7 @@ class RTDetrV2Config(PreTrainedConfig):
             The epsilon used by the layer normalization layers.
         batch_norm_eps (`float`, *optional*, defaults to 1e-05):
             The epsilon used by the batch normalization layers.
-        backbone_config (`Dict`, *optional*, defaults to `RTDetrV2ResNetConfig()`):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `RTDetrV2ResNetConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -357,8 +356,8 @@ class RTDetrV2Config(PreTrainedConfig):
         self.decoder_n_levels = decoder_n_levels
         self.decoder_offset_scale = decoder_offset_scale
         self.decoder_method = decoder_method
+
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True
 
 
 __all__ = ["RTDetrV2Config"]

transformers/models/rt_detr_v2/modeling_rt_detr_v2.py
@@ -506,6 +506,10 @@ class RTDetrV2PreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
 
         elif isinstance(module, nn.LayerNorm):
             init.ones_(module.weight)
@@ -515,6 +519,9 @@ class RTDetrV2PreTrainedModel(PreTrainedModel):
             init.xavier_uniform_(module.weight_embedding.weight)
         if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:
             init.xavier_uniform_(module.denoising_class_embed.weight)
+        if isinstance(module, RTDetrV2MultiscaleDeformableAttention):
+            n_points_scale = [1 / n for n in module.n_points_list for _ in range(n)]
+            init.copy_(module.n_points_scale, torch.tensor(n_points_scale, dtype=torch.float32))
 
 
 @dataclass
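A worked example of the `n_points_scale` construction used in the `_init_weights` additions above and in the modular file below, with a hypothetical `n_points_list` of `[4, 4, 2]` (one entry per feature level):

n_points_list = [4, 4, 2]
n_points_scale = [1 / n for n in n_points_list for _ in range(n)]
# -> [0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5]
# Each sampling point carries the reciprocal of its level's point count,
# so every feature level contributes the same total weight.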
transformers/models/rt_detr_v2/modular_rt_detr_v2.py
@@ -19,6 +19,7 @@ import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 
+from ... import initialization as init
 from ...configuration_utils import PreTrainedConfig
 from ...utils import is_torchdynamo_compiling, logging
 from ...utils.backbone_utils import (
@@ -59,7 +60,7 @@ class RTDetrV2Config(PreTrainedConfig):
             The epsilon used by the layer normalization layers.
         batch_norm_eps (`float`, *optional*, defaults to 1e-05):
             The epsilon used by the batch normalization layers.
-        backbone_config (`Dict`, *optional*, defaults to `RTDetrV2ResNetConfig()`):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `RTDetrV2ResNetConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -367,8 +368,8 @@ class RTDetrV2Config(PreTrainedConfig):
         self.decoder_n_levels = decoder_n_levels
         self.decoder_offset_scale = decoder_offset_scale
         self.decoder_method = decoder_method
+
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True
 
 
 def multi_scale_deformable_attention_v2(
@@ -564,7 +565,11 @@ class RTDetrV2DecoderLayer(RTDetrDecoderLayer):
 
 
 class RTDetrV2PreTrainedModel(RTDetrPreTrainedModel):
-    pass
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, RTDetrV2MultiscaleDeformableAttention):
+            n_points_scale = [1 / n for n in module.n_points_list for _ in range(n)]
+            init.copy_(module.n_points_scale, torch.tensor(n_points_scale, dtype=torch.float32))
 
 
 class RTDetrV2Decoder(RTDetrDecoder):

transformers/models/rwkv/modeling_rwkv.py
@@ -49,7 +49,7 @@ def load_wkv_cuda_kernel(context_length):
     if not is_kernels_available():
         raise ImportError("kernels is not installed, please install it with `pip install kernels`")
 
-    from kernels import get_kernel
+    from ...integrations.hub_kernels import get_kernel
 
     rwkv_cuda_kernel = get_kernel("kernels-community/rwkv")
     rwkv_cuda_kernel.max_seq_length = context_length

transformers/models/sam/configuration_sam.py
@@ -249,6 +249,7 @@ class SamVisionConfig(PreTrainedConfig):
         self.global_attn_indexes = global_attn_indexes
         self.num_pos_feats = num_pos_feats
         self.mlp_dim = int(hidden_size * mlp_ratio) if mlp_dim is None else mlp_dim
+        self.scale = self.hidden_size // 2
 
 
 class SamConfig(PreTrainedConfig):

transformers/models/sam/image_processing_sam_fast.py
@@ -267,7 +267,6 @@ class SamImageProcessorFast(BaseImageProcessorFast):
         if do_pad:
             processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)
 
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
         return BatchFeature(
             data={"pixel_values": processed_images, "reshaped_input_sizes": reshaped_input_sizes},
             tensor_type=return_tensors,

transformers/models/sam/modeling_sam.py
@@ -548,7 +548,7 @@ class SamMaskDecoder(nn.Module):
 class SamPositionalEmbedding(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.scale = config.hidden_size // 2
+        self.scale = config.scale
         self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats)))
 
     def forward(self, input_coords, input_shape=None):
@@ -1014,6 +1014,8 @@ class SamPreTrainedModel(PreTrainedModel):
         elif isinstance(module, SamVisionEncoder):
            if self.config.use_abs_pos:
                 init.zeros_(module.pos_embed)
+        elif isinstance(module, SamPositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)
 
 
 class SamVisionEncoder(SamPreTrainedModel):
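Storing `scale` on the config (the `configuration_sam.py` hunk above) lets `_init_weights` re-draw the `positional_embedding` buffer with the same distribution the constructor uses, since multiplying a standard-normal draw by `scale` and drawing with `std=scale` are equivalent. A minimal sketch of that equivalence, with illustrative shapes and values:

import torch
from torch import nn

class PositionalEmbedding(nn.Module):
    def __init__(self, num_pos_feats=128, scale=64.0):
        super().__init__()
        self.scale = scale
        # Constructor-time draw: Gaussian with std == scale.
        self.register_buffer("positional_embedding", self.scale * torch.randn((2, num_pos_feats)))

pe = PositionalEmbedding()
with torch.no_grad():
    # _init_weights-time re-draw: same distribution, so a model materialized
    # from the meta device can be re-initialized to matching statistics.
    pe.positional_embedding.normal_(std=pe.scale)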
@@ -1048,6 +1050,7 @@ class SamVisionEncoder(SamPreTrainedModel):
1048
1050
  self.neck = SamVisionNeck(config)
1049
1051
 
1050
1052
  self.gradient_checkpointing = False
1053
+ self.post_init()
1051
1054
 
1052
1055
  def get_input_embeddings(self):
1053
1056
  return self.patch_embed
@@ -152,7 +152,7 @@ class Sam2VisionConfig(PreTrainedConfig):
152
152
  documentation from [`PreTrainedConfig`] for more information.
153
153
 
154
154
  Args:
155
- backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*):
155
+ backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `Sam2HieraDetConfig()`):
156
156
  Configuration for the vision backbone. This is used to instantiate the backbone using
157
157
  `AutoModel.from_config`.
158
158
  backbone_channel_list (`List[int]`, *optional*, defaults to `[768, 384, 192, 96]`):
@@ -565,7 +565,9 @@ class Sam2PreTrainedModel(PreTrainedModel):
565
565
  init.zeros_(module.pos_embed)
566
566
  if module.pos_embed_window is not None:
567
567
  init.zeros_(module.pos_embed_window)
568
- if isinstance(module, Sam2Model):
568
+ elif isinstance(module, Sam2PositionalEmbedding):
569
+ init.normal_(module.positional_embedding, std=module.scale)
570
+ elif isinstance(module, Sam2Model):
569
571
  if module.no_memory_embedding is not None:
570
572
  init.zeros_(module.no_memory_embedding)
571
573
 
@@ -600,6 +602,8 @@ class Sam2HieraDetModel(Sam2PreTrainedModel):
600
602
  self.blocks.append(block)
601
603
  total_block_idx += 1
602
604
 
605
+ self.post_init()
606
+
603
607
  def get_input_embeddings(self):
604
608
  return self.patch_embed
605
609
 
@@ -681,7 +681,9 @@ class Sam2PreTrainedModel(PreTrainedModel):
              init.zeros_(module.pos_embed)
              if module.pos_embed_window is not None:
                  init.zeros_(module.pos_embed_window)
-         if isinstance(module, Sam2Model):
+         elif isinstance(module, Sam2PositionalEmbedding):
+             init.normal_(module.positional_embedding, std=module.scale)
+         elif isinstance(module, Sam2Model):
              if module.no_memory_embedding is not None:
                  init.zeros_(module.no_memory_embedding)
@@ -716,6 +718,8 @@ class Sam2HieraDetModel(Sam2PreTrainedModel):
              self.blocks.append(block)
              total_block_idx += 1

+         self.post_init()
+
      def get_input_embeddings(self):
          return self.patch_embed
@@ -209,7 +209,7 @@ class Sam2VideoInferenceSession:
          device_inputs = {}
          for key, value in inputs.items():
              if isinstance(value, torch.Tensor):
-                 device_inputs[key] = value.to(self.inference_device, non_blocking=True)
+                 device_inputs[key] = value.to(self.inference_device, non_blocking=False)
              else:
                  device_inputs[key] = value
          self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs
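This hunk (and its modular twin below) switches the host-to-device copy to `non_blocking=False`. A plausible reading, stated as an assumption rather than the authors' rationale: `non_blocking=True` only yields a truly asynchronous copy from pinned host memory, and asynchronous transfers require explicit synchronization before the result is consumed, so a blocking copy is the conservative default. A sketch of the safe async pattern:

```python
# Sketch (not from the diff): non_blocking=True only overlaps the copy with
# compute when the source tensor is in pinned (page-locked) host memory, and
# the caller must synchronize before relying on the result elsewhere.
import torch

if torch.cuda.is_available():
    x = torch.rand(1024, 1024).pin_memory()  # page-locked source buffer
    y = x.to("cuda", non_blocking=True)      # genuinely asynchronous H2D copy
    torch.cuda.synchronize()                 # ensure the transfer has finished
```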
@@ -688,6 +688,12 @@ class Sam2VideoPreTrainedModel(PreTrainedModel):
          if isinstance(module, Sam2VideoMemoryFuserCXBlock):
              if module.scale is not None:
                  init.zeros_(module.scale)
+         elif isinstance(module, Sam2VideoVisionRotaryEmbedding):
+             inv_freq = module.create_inv_freq()
+             init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+             init.copy_(module.rope_embeddings_sin, inv_freq.sin())
+         elif isinstance(module, Sam2VideoPositionalEmbedding):
+             init.normal_(module.positional_embedding, std=module.scale)


  class Sam2VideoVisionRotaryEmbedding(nn.Module):
@@ -698,24 +704,17 @@ class Sam2VideoVisionRotaryEmbedding(nn.Module):

      def __init__(self, config: Sam2VideoConfig):
          super().__init__()
-         dim = config.memory_attention_hidden_size // (
+         self.dim = config.memory_attention_hidden_size // (
              config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
          )
          # Ensure even dimension for proper axial splitting
-         if dim % 4 != 0:
+         if self.dim % 4 != 0:
              raise ValueError("Dimension must be divisible by 4 for axial RoPE")
-         end_x, end_y = config.memory_attention_rope_feat_sizes
-         freqs = 1.0 / (config.memory_attention_rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+         self.end_x, self.end_y = config.memory_attention_rope_feat_sizes
+         self.memory_attention_rope_theta = config.memory_attention_rope_theta

-         # Generate 2D position indices for axial rotary embedding
-         flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
-         x_positions = flattened_indices % end_x
-         y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
-         freqs_x = torch.outer(x_positions, freqs).float()
-         freqs_y = torch.outer(y_positions, freqs).float()
-         inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
-         inv_freq = inv_freq.repeat_interleave(2, dim=-1)
          # directly register the cos and sin embeddings as we have a fixed feature shape
+         inv_freq = self.create_inv_freq()
          self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
          self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)
@@ -724,6 +723,20 @@ class Sam2VideoVisionRotaryEmbedding(nn.Module):
          # As the feature map size is fixed, we can just return the pre-computed embeddings.
          return self.rope_embeddings_cos, self.rope_embeddings_sin

+     def create_inv_freq(self):
+         freqs = 1.0 / (
+             self.memory_attention_rope_theta ** (torch.arange(0, self.dim, 4)[: (self.dim // 4)].float() / self.dim)
+         )
+         # Generate 2D position indices for axial rotary embedding
+         flattened_indices = torch.arange(self.end_x * self.end_y, dtype=torch.long)
+         x_positions = flattened_indices % self.end_x
+         y_positions = torch.div(flattened_indices, self.end_x, rounding_mode="floor")
+         freqs_x = torch.outer(x_positions, freqs).float()
+         freqs_y = torch.outer(y_positions, freqs).float()
+         inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+         inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+         return inv_freq
+

  def rotate_pairwise(x):
      """
@@ -1101,6 +1114,31 @@ class Sam2VideoMemoryEncoder(nn.Module):
          return vision_features, vision_pos_enc


+ class Sam2VideoPositionalEmbedding(nn.Module):
+     def __init__(self, config: Sam2VideoPromptEncoderConfig):
+         super().__init__()
+         self.scale = config.scale
+         positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
+         self.register_buffer("positional_embedding", positional_embedding)
+
+     def forward(self, input_coords, input_shape=None):
+         """Positionally encode points that are normalized to [0,1]."""
+         coordinates = input_coords.clone()
+
+         if input_shape is not None:
+             coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
+             coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
+             coordinates.to(torch.float32)
+
+         # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+         coordinates = 2 * coordinates - 1
+         coordinates = coordinates.to(self.positional_embedding.dtype)
+         coordinates = coordinates @ self.positional_embedding
+         coordinates = 2 * np.pi * coordinates
+         # outputs d_1 x ... x d_n x channel shape
+         return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
+
+
  @dataclass
  @auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
  class Sam2VideoVisionEncoderOutput(ModelOutput):
@@ -1130,31 +1168,6 @@ class Sam2VideoVisionEncoderOutput(ModelOutput):
      attentions: Optional[tuple[torch.FloatTensor, ...]] = None


- class Sam2VideoPositionalEmbedding(nn.Module):
-     def __init__(self, config: Sam2VideoPromptEncoderConfig):
-         super().__init__()
-         self.scale = config.scale
-         positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
-         self.register_buffer("positional_embedding", positional_embedding)
-
-     def forward(self, input_coords, input_shape=None):
-         """Positionally encode points that are normalized to [0,1]."""
-         coordinates = input_coords.clone()
-
-         if input_shape is not None:
-             coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
-             coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
-             coordinates.to(torch.float32)
-
-         # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
-         coordinates = 2 * coordinates - 1
-         coordinates = coordinates.to(self.positional_embedding.dtype)
-         coordinates = coordinates @ self.positional_embedding
-         coordinates = 2 * np.pi * coordinates
-         # outputs d_1 x ... x d_n x channel shape
-         return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
-
-
  class Sam2VideoMaskEmbedding(nn.Module):
      def __init__(self, config: Sam2VideoPromptEncoderConfig):
          super().__init__()
@@ -1559,11 +1572,6 @@ class Sam2VideoModel(Sam2VideoPreTrainedModel):
      input_modalities = ("video", "text")
      _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2VideoTwoWayAttentionBlock, index=2)}
      _keys_to_ignore_on_load_unexpected = []
-     _tied_weights_keys = {
-         "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding"
-     }
-     # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
-     _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]

      def __init__(self, config: Sam2VideoConfig):
          super().__init__(config)
@@ -51,6 +51,7 @@ from ..sam2.modeling_sam2 import (
      Sam2ImageSegmentationOutput,
      Sam2LayerNorm,
      Sam2Model,
+     Sam2PositionalEmbedding,
      Sam2SinePositionEmbedding,
      Sam2TwoWayAttentionBlock,
      eager_attention_forward,
@@ -477,7 +478,7 @@ class Sam2VideoInferenceSession:
          device_inputs = {}
          for key, value in inputs.items():
              if isinstance(value, torch.Tensor):
-                 device_inputs[key] = value.to(self.inference_device, non_blocking=True)
+                 device_inputs[key] = value.to(self.inference_device, non_blocking=False)
              else:
                  device_inputs[key] = value
          self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs
@@ -1013,6 +1014,12 @@ class Sam2VideoPreTrainedModel(PreTrainedModel):
          if isinstance(module, Sam2VideoMemoryFuserCXBlock):
              if module.scale is not None:
                  init.zeros_(module.scale)
+         elif isinstance(module, Sam2VideoVisionRotaryEmbedding):
+             inv_freq = module.create_inv_freq()
+             init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+             init.copy_(module.rope_embeddings_sin, inv_freq.sin())
+         elif isinstance(module, Sam2VideoPositionalEmbedding):
+             init.normal_(module.positional_embedding, std=module.scale)


  class Sam2VideoVisionRotaryEmbedding(nn.Module):
@@ -1023,24 +1030,17 @@ class Sam2VideoVisionRotaryEmbedding(nn.Module):

      def __init__(self, config: Sam2VideoConfig):
          super().__init__()
-         dim = config.memory_attention_hidden_size // (
+         self.dim = config.memory_attention_hidden_size // (
              config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
          )
          # Ensure even dimension for proper axial splitting
-         if dim % 4 != 0:
+         if self.dim % 4 != 0:
              raise ValueError("Dimension must be divisible by 4 for axial RoPE")
-         end_x, end_y = config.memory_attention_rope_feat_sizes
-         freqs = 1.0 / (config.memory_attention_rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+         self.end_x, self.end_y = config.memory_attention_rope_feat_sizes
+         self.memory_attention_rope_theta = config.memory_attention_rope_theta

-         # Generate 2D position indices for axial rotary embedding
-         flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
-         x_positions = flattened_indices % end_x
-         y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
-         freqs_x = torch.outer(x_positions, freqs).float()
-         freqs_y = torch.outer(y_positions, freqs).float()
-         inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
-         inv_freq = inv_freq.repeat_interleave(2, dim=-1)
          # directly register the cos and sin embeddings as we have a fixed feature shape
+         inv_freq = self.create_inv_freq()
          self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
          self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)
@@ -1049,6 +1049,20 @@ class Sam2VideoVisionRotaryEmbedding(nn.Module):
          # As the feature map size is fixed, we can just return the pre-computed embeddings.
          return self.rope_embeddings_cos, self.rope_embeddings_sin

+     def create_inv_freq(self):
+         freqs = 1.0 / (
+             self.memory_attention_rope_theta ** (torch.arange(0, self.dim, 4)[: (self.dim // 4)].float() / self.dim)
+         )
+         # Generate 2D position indices for axial rotary embedding
+         flattened_indices = torch.arange(self.end_x * self.end_y, dtype=torch.long)
+         x_positions = flattened_indices % self.end_x
+         y_positions = torch.div(flattened_indices, self.end_x, rounding_mode="floor")
+         freqs_x = torch.outer(x_positions, freqs).float()
+         freqs_y = torch.outer(y_positions, freqs).float()
+         inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+         inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+         return inv_freq
+

  def rotate_pairwise(x):
      """
@@ -1426,6 +1440,10 @@ class Sam2VideoMemoryEncoder(nn.Module):
          return vision_features, vision_pos_enc


+ class Sam2VideoPositionalEmbedding(Sam2PositionalEmbedding):
+     pass
+
+
  # a large negative value as a placeholder score for missing objects
  NO_OBJ_SCORE = -1024.0
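In the modular file, the class is re-declared as an empty subclass of the imported `Sam2PositionalEmbedding` (see the import hunk above); the modular converter expands such `pass` subclasses into a full copy of the parent under the new model's name. A toy illustration of the aliasing pattern, with a stand-in parent class:

```python
# An empty subclass reuses the parent implementation while giving the
# video model its own class name; Sam2PositionalEmbedding here is a stand-in.
class Sam2PositionalEmbedding:
    def __init__(self, scale: float = 1.0):
        self.scale = scale

class Sam2VideoPositionalEmbedding(Sam2PositionalEmbedding):
    pass

emb = Sam2VideoPositionalEmbedding()
print(type(emb).__name__)  # Sam2VideoPositionalEmbedding, behavior from parent
```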
@@ -1446,11 +1464,6 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000):
  @auto_docstring
  class Sam2VideoModel(Sam2Model):
      input_modalities = ("video", "text")
-     _tied_weights_keys = {
-         "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding"
-     }
-     # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
-     _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
      _keys_to_ignore_on_load_unexpected = []
      _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2VideoTwoWayAttentionBlock, index=2)}
@@ -122,7 +122,7 @@ class Sam3VisionConfig(PreTrainedConfig):
      documentation from [`PreTrainedConfig`] for more information.

      Args:
-         backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*):
+         backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `Sam3ViTConfig()`):
              Configuration for the vision backbone. This is used to instantiate the backbone using
              `AutoModel.from_config`.
          fpn_hidden_size (`int`, *optional*, defaults to 256):
@@ -179,6 +179,16 @@ class Sam3VisionConfig(PreTrainedConfig):
          self.initializer_range = initializer_range
          super().__init__(**kwargs)

+     @property
+     def image_size(self):
+         """Image size for the vision encoder."""
+         return self.backbone_config.image_size
+
+     @image_size.setter
+     def image_size(self, value):
+         """Set the image size and propagate to backbone."""
+         self.backbone_config.image_size = value
+

  class Sam3GeometryEncoderConfig(PreTrainedConfig):
      r"""
@@ -506,6 +516,16 @@ class Sam3Config(PreTrainedConfig):
          self.initializer_range = initializer_range
          super().__init__(**kwargs)

+     @property
+     def image_size(self):
+         """Image size for the SAM3 model."""
+         return self.vision_config.image_size
+
+     @image_size.setter
+     def image_size(self, value):
+         """Set the image size and propagate to vision config."""
+         self.vision_config.image_size = value
+

  __all__ = [
      "Sam3Config",
@@ -417,6 +417,10 @@ class Sam3ViTRotaryEmbedding(nn.Module):
          # Ensure even dimension for proper axial splitting
          if dim % 4 != 0:
              raise ValueError("Dimension must be divisible by 4 for axial RoPE")
+         self.end_x, self.end_y = end_x, end_y
+         self.dim = dim
+         self.rope_theta = config.rope_theta
+         self.scale = scale
          freqs = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))

          flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
@@ -776,6 +780,19 @@ class Sam3PreTrainedModel(PreTrainedModel):
          super()._init_weights(module)
          if isinstance(module, Sam3ViTEmbeddings):
              init.normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
+         elif isinstance(module, Sam3ViTRotaryEmbedding):
+             end_x, end_y = module.end_x, module.end_y
+             dim = module.dim
+             freqs = 1.0 / (module.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+             flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
+             x_positions = (flattened_indices % end_x) * module.scale
+             y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * module.scale
+             freqs_x = torch.outer(x_positions, freqs).float()
+             freqs_y = torch.outer(y_positions, freqs).float()
+             inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+             inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+             init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+             init.copy_(module.rope_embeddings_sin, inv_freq.sin())


  @auto_docstring
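Like the Sam2Video rotary hunks above, the `Sam3ViTRotaryEmbedding` change stores the constructor inputs (`end_x`, `end_y`, `dim`, `rope_theta`, `scale`) on the module so `_init_weights` can recompute the non-persistent `rope_embeddings_cos`/`rope_embeddings_sin` buffers, which never appear in checkpoints. A toy sketch of that re-materialization idea, under the assumption that the buffers must be rebuildable from module attributes alone:

```python
# A non-persistent buffer is excluded from the state_dict, so it must be
# recomputable from attributes kept on the module.
import torch
import torch.nn as nn

class FreqTable(nn.Module):
    def __init__(self, n: int = 16):
        super().__init__()
        self.n = n  # keep the inputs needed to rebuild the buffer
        self.register_buffer("table", self.create_table(), persistent=False)

    def create_table(self) -> torch.Tensor:
        return torch.arange(self.n).float().cos()

m = FreqTable()
assert "table" not in m.state_dict()    # non-persistent: not checkpointed
with torch.no_grad():
    m.table.copy_(m.create_table())     # analogous to the init.copy_ calls above
```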
@@ -1338,6 +1355,8 @@ class Sam3DetrEncoder(Sam3PreTrainedModel):

          self.layers = nn.ModuleList([Sam3DetrEncoderLayer(config) for _ in range(config.num_layers)])

+         self.post_init()
+
      def _prepare_multilevel_features(
          self,
          vision_features: list[torch.Tensor],
@@ -1617,6 +1636,8 @@ class Sam3DetrDecoder(Sam3PreTrainedModel):

          self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.hidden_size // 2, normalize=False)

+         self.post_init()
+
      @compile_compatible_method_lru_cache(maxsize=1)
      def _get_coords(
          self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype, device: torch.device
@@ -1987,6 +2008,8 @@ class Sam3MaskDecoder(Sam3PreTrainedModel):
          self.prompt_cross_attn_norm = nn.LayerNorm(hidden_size)
          self.prompt_cross_attn_dropout = nn.Dropout(config.dropout)

+         self.post_init()
+
      @check_model_inputs
      def forward(
          self,