transformers 5.0.0rc1-py3-none-any.whl → 5.0.0rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -128,6 +128,8 @@ class Sam3TrackerPreTrainedModel(PreTrainedModel):
         if isinstance(module, Sam3TrackerModel):
             if module.no_memory_embedding is not None:
                 init.zeros_(module.no_memory_embedding)
+        elif isinstance(module, Sam3TrackerPositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)


 class Sam3TrackerPositionalEmbedding(nn.Module):
@@ -149,6 +149,8 @@ class Sam3TrackerPreTrainedModel(Sam2PreTrainedModel):
         if isinstance(module, Sam3TrackerModel):
             if module.no_memory_embedding is not None:
                 init.zeros_(module.no_memory_embedding)
+        elif isinstance(module, Sam3TrackerPositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)


 class Sam3TrackerPositionalEmbedding(Sam2PositionalEmbedding):
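Note: in both `_init_weights` hunks above, `init.normal_(module.positional_embedding, std=module.scale)` reproduces how the buffer is built at construction time (`scale * torch.randn(...)`), since scaling a standard normal by `scale` and sampling N(0, scale²) directly are distributionally the same. A minimal sketch with a toy scale value (64.0 is illustrative, not taken from any real config):

    import torch

    scale = 64.0  # illustrative, not a real config value
    a = scale * torch.randn(100_000)                                      # construction-time style
    b = torch.nn.init.normal_(torch.empty(100_000), mean=0.0, std=scale)  # _init_weights style
    print(round(a.std().item()), round(b.std().item()))                   # both ~64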
@@ -397,5 +397,30 @@ class Sam3TrackerVideoConfig(PreTrainedConfig):

         super().__init__(**kwargs)

+    @property
+    def image_size(self):
+        """Image size for the tracker video model."""
+        return self.vision_config.image_size
+
+    @image_size.setter
+    def image_size(self, value):
+        """Set the image size and propagate to sub-configs. Calculates feature sizes based on patch_size."""
+        self.prompt_encoder_config.image_size = value
+        self.vision_config.image_size = value
+
+        patch_size = self.vision_config.backbone_config.patch_size
+        self.vision_config.backbone_feature_sizes = [
+            [4 * value // patch_size, 4 * value // patch_size],
+            [2 * value // patch_size, 2 * value // patch_size],
+            [value // patch_size, value // patch_size],
+        ]
+        self.memory_attention_rope_feat_sizes = [
+            value // patch_size,
+            value // patch_size,
+        ]
+
+        # keep the image_size in the __dict__ to save the value in the config file (backward compatibility)
+        self.__dict__["image_size"] = value
+

 __all__ = ["Sam3TrackerVideoMaskDecoderConfig", "Sam3TrackerVideoPromptEncoderConfig", "Sam3TrackerVideoConfig"]
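Note: a quick check of the new setter's arithmetic. With an assumed `image_size=1024` and `patch_size=16` (illustrative numbers, not the model's defaults), the three backbone levels come out at 4x, 2x, and 1x the patch-grid resolution:

    image_size, patch_size = 1024, 16  # assumed values for illustration
    backbone_feature_sizes = [
        [4 * image_size // patch_size, 4 * image_size // patch_size],
        [2 * image_size // patch_size, 2 * image_size // patch_size],
        [image_size // patch_size, image_size // patch_size],
    ]
    memory_attention_rope_feat_sizes = [image_size // patch_size, image_size // patch_size]
    print(backbone_feature_sizes)            # [[256, 256], [128, 128], [64, 64]]
    print(memory_attention_rope_feat_sizes)  # [64, 64]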
@@ -213,7 +213,7 @@ class Sam3TrackerVideoInferenceSession:
         device_inputs = {}
         for key, value in inputs.items():
             if isinstance(value, torch.Tensor):
-                device_inputs[key] = value.to(self.inference_device, non_blocking=True)
+                device_inputs[key] = value.to(self.inference_device, non_blocking=False)
             else:
                 device_inputs[key] = value
         self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs
@@ -692,6 +692,12 @@ class Sam3TrackerVideoPreTrainedModel(PreTrainedModel):
         if isinstance(module, Sam3TrackerVideoMemoryFuserCXBlock):
             if module.scale is not None:
                 init.zeros_(module.scale)
+        elif isinstance(module, Sam3TrackerVideoVisionRotaryEmbedding):
+            inv_freq = module.create_inv_freq()
+            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+            init.copy_(module.rope_embeddings_sin, inv_freq.sin())
+        elif isinstance(module, Sam3TrackerVideoPositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)


 class Sam3TrackerVideoVisionRotaryEmbedding(nn.Module):
@@ -702,24 +708,17 @@ class Sam3TrackerVideoVisionRotaryEmbedding(nn.Module):

     def __init__(self, config: Sam3TrackerVideoConfig):
         super().__init__()
-        dim = config.memory_attention_hidden_size // (
+        self.dim = config.memory_attention_hidden_size // (
             config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
         )
         # Ensure even dimension for proper axial splitting
-        if dim % 4 != 0:
+        if self.dim % 4 != 0:
             raise ValueError("Dimension must be divisible by 4 for axial RoPE")
-        end_x, end_y = config.memory_attention_rope_feat_sizes
-        freqs = 1.0 / (config.memory_attention_rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+        self.end_x, self.end_y = config.memory_attention_rope_feat_sizes
+        self.memory_attention_rope_theta = config.memory_attention_rope_theta

-        # Generate 2D position indices for axial rotary embedding
-        flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
-        x_positions = flattened_indices % end_x
-        y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
-        freqs_x = torch.outer(x_positions, freqs).float()
-        freqs_y = torch.outer(y_positions, freqs).float()
-        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
-        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
         # directly register the cos and sin embeddings as we have a fixed feature shape
+        inv_freq = self.create_inv_freq()
         self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
         self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)

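Note: the `create_inv_freq()` helper introduced in the next hunk rebuilds the axial 2D RoPE angle table from the fields stashed above. A standalone sketch of the same computation, with made-up sizes (`dim=8`, a 4x4 grid, and `theta=10000.0` are illustrative, not real config defaults):

    import torch

    dim, end_x, end_y, theta = 8, 4, 4, 10000.0  # toy values
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: dim // 4].float() / dim))
    idx = torch.arange(end_x * end_y, dtype=torch.long)
    x_pos = idx % end_x                                   # column of each flattened position
    y_pos = torch.div(idx, end_x, rounding_mode="floor")  # row of each flattened position
    angles = torch.cat([torch.outer(x_pos, freqs), torch.outer(y_pos, freqs)], dim=-1)
    angles = angles.repeat_interleave(2, dim=-1)          # duplicate per (even, odd) channel pair
    print(angles.shape)  # torch.Size([16, 8]): one row of angles per (x, y) grid cell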
@@ -728,6 +727,20 @@ class Sam3TrackerVideoVisionRotaryEmbedding(nn.Module):
         # As the feature map size is fixed, we can just return the pre-computed embeddings.
         return self.rope_embeddings_cos, self.rope_embeddings_sin

+    def create_inv_freq(self):
+        freqs = 1.0 / (
+            self.memory_attention_rope_theta ** (torch.arange(0, self.dim, 4)[: (self.dim // 4)].float() / self.dim)
+        )
+        # Generate 2D position indices for axial rotary embedding
+        flattened_indices = torch.arange(self.end_x * self.end_y, dtype=torch.long)
+        x_positions = flattened_indices % self.end_x
+        y_positions = torch.div(flattened_indices, self.end_x, rounding_mode="floor")
+        freqs_x = torch.outer(x_positions, freqs).float()
+        freqs_y = torch.outer(y_positions, freqs).float()
+        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+        return inv_freq
+

 def rotate_pairwise(x):
     """
@@ -1567,8 +1580,6 @@ class Sam3TrackerVideoModel(Sam3TrackerVideoPreTrainedModel):
     input_modalities = ("video", "text")
     _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam3TrackerVideoTwoWayAttentionBlock, index=2)}
     _keys_to_ignore_on_load_unexpected = [r"^detector_model."]
-    _tied_weights_keys = {}
-    _keys_to_ignore_on_load_missing = []
     _checkpoint_conversion_mapping = {
         r"tracker_model.(.+)": r"\1",  # the regex allows to remove the prefix, and add it back in revert mode
         "detector_model.vision_encoder.backbone.": "vision_encoder.backbone.",
@@ -353,6 +353,31 @@ class Sam3TrackerVideoConfig(PreTrainedConfig):

         super().__init__(**kwargs)

+    @property
+    def image_size(self):
+        """Image size for the tracker video model."""
+        return self.vision_config.image_size
+
+    @image_size.setter
+    def image_size(self, value):
+        """Set the image size and propagate to sub-configs. Calculates feature sizes based on patch_size."""
+        self.prompt_encoder_config.image_size = value
+        self.vision_config.image_size = value
+
+        patch_size = self.vision_config.backbone_config.patch_size
+        self.vision_config.backbone_feature_sizes = [
+            [4 * value // patch_size, 4 * value // patch_size],
+            [2 * value // patch_size, 2 * value // patch_size],
+            [value // patch_size, value // patch_size],
+        ]
+        self.memory_attention_rope_feat_sizes = [
+            value // patch_size,
+            value // patch_size,
+        ]
+
+        # keep the image_size in the __dict__ to save the value in the config file (backward compatibility)
+        self.__dict__["image_size"] = value
+

 class Sam3TrackerVideoInferenceCache(Sam2VideoInferenceCache):
     pass
@@ -461,8 +486,6 @@ class Sam3TrackerVideoModel(Sam2VideoModel):
         "tracker_neck.": "vision_encoder.neck.",
     }
     _keys_to_ignore_on_load_unexpected = [r"^detector_model."]
-    _tied_weights_keys = {}
-    _keys_to_ignore_on_load_missing = []

     def __init__(self, config: Sam3TrackerVideoConfig, remove_vision_encoder: bool = False):
         r"""
@@ -96,6 +96,9 @@ class Sam3VideoConfig(PreTrainedConfig):
     >>> # Initializing a SAM3 Video configuration with default detector and tracker
     >>> configuration = Sam3VideoConfig()

+    >>> # Changing image size for custom resolution inference (automatically propagates to all nested configs)
+    >>> configuration.image_size = 560
+
     >>> # Initializing a model from the configuration
     >>> model = Sam3VideoModel(configuration)

 
@@ -225,5 +228,16 @@ class Sam3VideoConfig(PreTrainedConfig):
225
228
  self.high_conf_thresh = high_conf_thresh
226
229
  self.high_iou_thresh = high_iou_thresh
227
230
 
231
+ @property
232
+ def image_size(self):
233
+ """Image size for the video model."""
234
+ return self.detector_config.image_size
235
+
236
+ @image_size.setter
237
+ def image_size(self, value):
238
+ """Recursively propagate the image size to detector and tracker configs."""
239
+ self.detector_config.image_size = value
240
+ self.tracker_config.image_size = value
241
+
228
242
 
229
243
  __all__ = ["Sam3VideoConfig"]
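Note: combined with the tracker-level setter earlier in this diff, a single assignment now fans out through the whole config tree. A hedged usage sketch (it assumes `Sam3VideoConfig` is importable from the top-level `transformers` namespace in rc2, as other configs are):

    from transformers import Sam3VideoConfig  # assumed top-level export

    config = Sam3VideoConfig()
    config.image_size = 560  # mirrors the docstring example added above
    assert config.detector_config.image_size == 560
    assert config.tracker_config.image_size == 560
    assert config.image_size == 560  # the property reads back through detector_config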
@@ -33,7 +33,7 @@ from .configuration_sam3_video import Sam3VideoConfig


 if is_kernels_available():
-    from kernels import get_kernel
+    from ...integrations.hub_kernels import get_kernel

 logger = logging.get_logger(__name__)

@@ -505,8 +505,6 @@ class Sam3VideoPreTrainedModel(PreTrainedModel):

 @auto_docstring
 class Sam3VideoModel(Sam3VideoPreTrainedModel):
-    all_tied_weights_keys = {}
-
     def __init__(self, config: Sam3VideoConfig):
         super().__init__(config)
         self.config = config
@@ -542,6 +540,8 @@ class Sam3VideoModel(Sam3VideoPreTrainedModel):

         self.tracker_neck = Sam3VisionNeck(config.detector_config.vision_config)

+        self.post_init()
+
     def get_vision_features_for_tracker(self, vision_embeds: torch.Tensor):
         hidden_states = vision_embeds.last_hidden_state
         batch_size = hidden_states.shape[0]
@@ -340,7 +340,7 @@ class Sam3VideoProcessor(ProcessorMixin):

         # slice those valid entries from the original outputs
         keep_idx = torch.nonzero(keep, as_tuple=True)[0]
-        keep_idx_gpu = keep_idx.pin_memory().to(device=out_binary_masks.device, non_blocking=True)
+        keep_idx_gpu = keep_idx.to(device=out_binary_masks.device, non_blocking=True)

         out_obj_ids = torch.index_select(out_obj_ids, 0, keep_idx)
         out_probs = torch.index_select(out_probs, 0, keep_idx)
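Note on the dropped `pin_memory()` (the diff itself gives no rationale): `non_blocking=True` only yields a truly asynchronous host-to-device copy when the source tensor already lives in pinned memory; with a pageable source, PyTorch falls back to a regular blocking copy, which is still correct. Pinning a small index tensor on every call presumably added allocation overhead for little overlap benefit. A sketch of the surviving path:

    import torch

    keep = torch.tensor([True, False, True, True])
    keep_idx = torch.nonzero(keep, as_tuple=True)[0]  # pageable CPU tensor
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Without pinned memory this copy is effectively synchronous, but always safe.
    keep_idx_gpu = keep_idx.to(device=device, non_blocking=True)
    print(keep_idx_gpu)  # tensor([0, 2, 3]) on `device`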
@@ -188,6 +188,7 @@ class SamHQVisionConfig(PreTrainedConfig):
         self.global_attn_indexes = global_attn_indexes
         self.num_pos_feats = num_pos_feats
         self.mlp_dim = int(hidden_size * mlp_ratio) if mlp_dim is None else mlp_dim
+        self.scale = self.hidden_size // 2


 class SamHQMaskDecoderConfig(PreTrainedConfig):
@@ -413,6 +413,29 @@ class SamHQVisionLayer(GradientCheckpointingLayer):
         return hidden_states


+class SamHQPositionalEmbedding(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.scale = config.scale
+        self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats)))
+
+    def forward(self, input_coords, input_shape=None):
+        """Positionally encode points that are normalized to [0,1]."""
+        coordinates = input_coords.clone()
+
+        if input_shape is not None:
+            coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
+            coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
+
+        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+        coordinates = 2 * coordinates - 1
+        coordinates = coordinates.to(self.positional_embedding.dtype)
+        coordinates = coordinates @ self.positional_embedding
+        coordinates = 2 * np.pi * coordinates
+        # outputs d_1 x ... x d_n x channel shape
+        return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
+
+
 @auto_docstring
 class SamHQPreTrainedModel(PreTrainedModel):
     config: SamHQConfig
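Note: the relocated `SamHQPositionalEmbedding` is a random-Fourier-feature encoding: normalized (x, y) points are shifted to [-1, 1], projected through a fixed Gaussian matrix, and expanded into sin/cos pairs, doubling `num_pos_feats` into the output channel count. A standalone sketch with assumed sizes (`num_pos_feats=128` and `scale=64.0` are illustrative, not SAM-HQ defaults):

    import math

    import torch

    num_pos_feats, scale = 128, 64.0  # assumed sizes
    gaussian_matrix = scale * torch.randn((2, num_pos_feats))  # plays the role of the buffer

    points = torch.rand(1, 1, 3, 2)  # batch x prompt x num_points x (x, y), already in [0, 1]
    coords = 2 * points - 1          # map [0, 1] -> [-1, 1]
    angles = 2 * math.pi * (coords @ gaussian_matrix)
    encoding = torch.cat([angles.sin(), angles.cos()], dim=-1)
    print(encoding.shape)            # torch.Size([1, 1, 3, 256])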
@@ -433,6 +456,8 @@ class SamHQPreTrainedModel(PreTrainedModel):
         elif isinstance(module, SamHQVisionEncoder):
             if self.config.use_abs_pos:
                 init.zeros_(module.pos_embed)
+        elif isinstance(module, SamHQPositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)


 class SamHQPatchEmbeddings(nn.Module):
@@ -525,6 +550,7 @@ class SamHQVisionEncoder(SamHQPreTrainedModel):
         self.neck = SamHQVisionNeck(config)

         self.gradient_checkpointing = False
+        self.post_init()

     def get_input_embeddings(self):
         return self.patch_embed
@@ -1069,29 +1095,6 @@ class SamHQVisionModel(SamHQPreTrainedModel):
         return self.vision_encoder(pixel_values, **kwargs)


-class SamHQPositionalEmbedding(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.scale = config.hidden_size // 2
-        self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats)))
-
-    def forward(self, input_coords, input_shape=None):
-        """Positionally encode points that are normalized to [0,1]."""
-        coordinates = input_coords.clone()
-
-        if input_shape is not None:
-            coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
-            coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
-
-        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
-        coordinates = 2 * coordinates - 1
-        coordinates = coordinates.to(self.positional_embedding.dtype)
-        coordinates = coordinates @ self.positional_embedding
-        coordinates = 2 * np.pi * coordinates
-        # outputs d_1 x ... x d_n x channel shape
-        return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
-
-
 class SamHQMaskEmbedding(nn.Module):
     def __init__(self, config: SamHQPromptEncoderConfig):
         super().__init__()
@@ -287,18 +287,17 @@ class SeamlessM4TConformerRelPositionalEmbedding(nn.Module):
         super().__init__()
         self.max_len = config.max_source_positions
         self.d_model = config.hidden_size
-        self.pe = None
-        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+        self.register_buffer("pe", self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)), persistent=False)
 
-    def extend_pe(self, x):
+    def extend_pe(self, x, pe=None):
         # Reset the positional encodings
-        if self.pe is not None:
+        if pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(1) >= x.size(1) * 2 - 1:
-                if self.pe.dtype != x.dtype or self.pe.device != x.device:
-                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
-                return
+            if pe.size(1) >= x.size(1) * 2 - 1:
+                if pe.dtype != x.dtype or pe.device != x.device:
+                    pe = pe.to(dtype=x.dtype, device=x.device)
+                return pe
         # Suppose `i` is the position of query vector and `j` is the
         # position of key vector. We use positive relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
@@ -319,10 +318,10 @@ class SeamlessM4TConformerRelPositionalEmbedding(nn.Module):
         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
         pe_negative = pe_negative[1:].unsqueeze(0)
         pe = torch.cat([pe_positive, pe_negative], dim=1)
-        self.pe = pe.to(device=x.device, dtype=x.dtype)
+        return pe.to(device=x.device, dtype=x.dtype)
 
     def forward(self, hidden_states: torch.Tensor):
-        self.extend_pe(hidden_states)
+        self.pe = self.extend_pe(hidden_states, self.pe)
         start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
         end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
         relative_position_embeddings = self.pe[:, start_idx:end_idx]
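This refactor makes `extend_pe` a pure function that returns the table, while `pe` becomes a non-persistent buffer: it now follows `module.to(...)` across devices and dtypes without ever being written into checkpoints. A minimal sketch of that buffer behavior, independent of this model:

import torch
from torch import nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=False: moved by .to()/.cuda() like a parameter,
        # but never serialized into state_dict().
        self.register_buffer("pe", torch.zeros(1, 8), persistent=False)

m = WithBuffer()
print("pe" in m.state_dict())        # False
print(m.to(torch.float16).pe.dtype)  # torch.float16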
@@ -884,13 +883,14 @@ class SeamlessM4TScaledWordEmbedding(nn.Embedding):
         return super().forward(input_ids) * self.embed_scale
 
 
-# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding
+# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding with M2M100->SeamlessM4T
 class SeamlessM4TSinusoidalPositionalEmbedding(nn.Module):
     """This module produces sinusoidal positional embeddings of any length."""
 
     def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
         super().__init__()
         self.offset = 2
+        self.num_positions = num_positions
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
@@ -1375,11 +1375,27 @@ class SeamlessM4TPreTrainedModel(PreTrainedModel):
         elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.Conv1d):
             init.kaiming_normal_(module.weight)
             if module.bias is not None:
                 k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                 init.uniform_(module.bias, a=-k, b=k)
+        elif isinstance(module, SeamlessM4TSinusoidalPositionalEmbedding):
+            emb_weights = module.get_embedding(
+                module.num_positions + module.offset, module.embedding_dim, module.padding_idx
+            )
+            init.copy_(module.weights, emb_weights)
+        elif isinstance(module, SeamlessM4TConformerRotaryPositionalEmbedding):
+            dim = self.config.hidden_size // self.config.speech_encoder_attention_heads
+            base = self.config.rotary_embedding_base
+            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+            init.copy_(module.inv_freq, inv_freq)
+        elif isinstance(module, SeamlessM4TConformerRelPositionalEmbedding):
+            init.copy_(module.pe, module.extend_pe(torch.tensor(0.0).expand(1, module.max_len)))
 
     def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask):
         kernel_size, stride = self.config.adaptor_kernel_size, self.config.adaptor_stride
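The new branches rebuild the deterministic tables inside `_init_weights` instead of relying on `__init__` side effects. The rotary branch is the standard RoPE inverse-frequency recipe; recomputed standalone below with assumed values (dim would be hidden_size // num_heads, base the model's rotary_embedding_base):

import torch

dim, base = 64, 10000.0  # assumed illustration values, not from this diff
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(inv_freq.shape)    # torch.Size([32]) -- one frequency per even channel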
@@ -762,6 +762,7 @@ class SeamlessM4Tv2SinusoidalPositionalEmbedding(nn.Module):
     def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
         super().__init__()
         self.offset = 2
+        self.num_positions = num_positions
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
@@ -1292,6 +1293,11 @@ class SeamlessM4Tv2PreTrainedModel(PreTrainedModel):
             if module.bias is not None:
                 k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                 init.uniform_(module.bias, a=-k, b=k)
+        elif isinstance(module, SeamlessM4Tv2SinusoidalPositionalEmbedding):
+            emb_weights = module.get_embedding(
+                module.num_positions + module.offset, module.embedding_dim, module.padding_idx
+            )
+            init.copy_(module.weights, emb_weights)
 
     # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TPreTrainedModel._compute_sub_sample_lengths_from_attention_mask
     def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask):
@@ -311,7 +311,7 @@ class SeedOssRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
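The `.clone()` here matters as much as the buffer registration: registering the same tensor under both names would alias one storage, so an in-place update to `inv_freq` (as dynamic RoPE variants perform) would silently corrupt the saved original. Illustration:

import torch

inv_freq = torch.tensor([1.0, 0.1, 0.01])
alias = inv_freq           # no clone: shares storage
backup = inv_freq.clone()  # clone: independent storage
inv_freq.mul_(0.5)         # e.g. a dynamic frequency rescale
print(alias)               # rescaled too -- the "original" is lost
print(backup)              # unchanged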
@@ -168,7 +168,6 @@ class SegformerImageProcessorFast(BaseImageProcessorFast):
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
 
         # Stack images into a single tensor if return_tensors is set
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
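The explicit `torch.stack` is dropped, presumably because `BatchFeature` is handed `tensor_type=return_tensors` and its conversion logic now batches the list itself. A hedged sketch of the behavior this hunk relies on (an assumption about rc2's BatchFeature, not verified here):

import torch
from transformers import BatchFeature

# Assumption: with tensor_type set, BatchFeature converts a list of
# equally-shaped tensors into one batched tensor, making torch.stack redundant.
images = [torch.rand(3, 224, 224) for _ in range(2)]
feat = BatchFeature(data={"pixel_values": images}, tensor_type="pt")
print(feat["pixel_values"].shape)  # expected: torch.Size([2, 3, 224, 224])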
@@ -549,9 +549,9 @@ class SegformerMLP(nn.Module):
         return hidden_states
 
 
-class SegformerDecodeHead(SegformerPreTrainedModel):
+class SegformerDecodeHead(nn.Module):
     def __init__(self, config):
-        super().__init__(config)
+        super().__init__()
         # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size
         mlps = []
         for i in range(config.num_encoder_blocks):
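Demoting the decode head from `SegformerPreTrainedModel` to `nn.Module` fits how it is used: it only ever lives inside the full model, whose `post_init()` initializes every submodule via `Module.apply`, so the head needs none of the config-loading machinery. The composition pattern in miniature (hypothetical classes, not the Segformer code):

from torch import nn

class DecodeHead(nn.Module):  # plain module: no config plumbing of its own
    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        self.classifier = nn.Conv2d(hidden_size, num_labels, kernel_size=1)

class Segmenter(nn.Module):   # stand-in for the owning pretrained model
    def __init__(self):
        super().__init__()
        self.decode_head = DecodeHead(256, 19)
        self.apply(self._init_weights)  # reaches DecodeHead's conv as well

    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            nn.init.normal_(module.weight, std=0.02)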
@@ -140,7 +140,6 @@ class SegformerImageProcessorFast(BeitImageProcessorFast):
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
 
         # Stack images into a single tensor if return_tensors is set
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
@@ -57,6 +57,7 @@ class ShieldGemma2ForImageClassification(PreTrainedModel):
         self.yes_token_index = getattr(config, "yes_token_index", 10_784)
         self.no_token_index = getattr(config, "no_token_index", 3771)
         self.model = AutoModelForImageTextToText.from_config(config=config)
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.model.language_model.get_input_embeddings()
@@ -430,6 +430,8 @@ class SiglipPreTrainedModel(PreTrainedModel):
                 else self.config.hidden_size
             )
             init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
+            if hasattr(module, "position_ids"):
+                init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, nn.Embedding):
             default_flax_embed_init(module.weight)
         elif isinstance(module, SiglipAttention):
@@ -465,6 +467,8 @@ class SiglipPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+        elif isinstance(module, SiglipTextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
 
 
 # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip
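Both new branches restore the deterministic `position_ids` buffer during weight init; the value is simply a [1, n] row of consecutive indices:

import torch

n = 16  # assumed sequence length for illustration
position_ids = torch.arange(n).expand((1, -1))
print(position_ids.shape, position_ids[0, :4].tolist())  # torch.Size([1, 16]) [0, 1, 2, 3]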
@@ -502,9 +506,11 @@ class SiglipEncoder(nn.Module):
         return BaseModelOutput(last_hidden_state=hidden_states)
 
 
-class SiglipTextTransformer(nn.Module):
+class SiglipTextTransformer(SiglipPreTrainedModel):
+    _input_embed_layer = "token_embedding"
+
     def __init__(self, config: SiglipTextConfig):
-        super().__init__()
+        super().__init__(config)
         self.config = config
         embed_dim = config.hidden_size
         self.embeddings = SiglipTextEmbeddings(config)
@@ -512,6 +518,7 @@ class SiglipTextTransformer(nn.Module):
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
 
         self.head = nn.Linear(embed_dim, config.projection_size)
+        self.post_init()
 
     @can_return_tuple
     @auto_docstring
@@ -614,6 +621,7 @@ class SiglipTextModel(SiglipPreTrainedModel):
 
 
 class SiglipVisionTransformer(SiglipPreTrainedModel):
+    _input_embed_layer = "patch_embedding"
     _can_record_outputs = {
         "hidden_states": SiglipEncoderLayer,
         "attentions": SiglipAttention,
@@ -631,6 +639,8 @@ class SiglipVisionTransformer(SiglipPreTrainedModel):
         if self.use_head:
             self.head = SiglipMultiheadAttentionPoolingHead(config)
 
+        self.post_init()
+
     @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
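With the text and vision transformers now subclassing `SiglipPreTrainedModel`, each calls `self.post_init()` as the last statement of `__init__`, the standard transformers idiom for running weight initialization once construction is complete. The shape of the pattern, as a minimal hypothetical (not the Siglip code):

from torch import nn

class TinyPreTrained(nn.Module):  # stand-in for PreTrainedModel
    def post_init(self):
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)

class TinyTransformer(TinyPreTrained):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.post_init()  # last line of __init__, as in the hunks above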
@@ -774,6 +784,12 @@ class SiglipModel(SiglipPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def get_input_embeddings(self) -> nn.Module:
+        return self.text_model.embeddings.token_embedding
+
+    def set_input_embeddings(self, value: nn.Module):
+        self.text_model.embeddings.token_embedding = value
+
     @filter_out_non_signature_kwargs()
     @auto_docstring
     def get_text_features(
@@ -969,6 +985,12 @@ class SiglipForImageClassification(SiglipPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def get_input_embeddings(self) -> nn.Module:
+        return self.vision_model.embeddings.patch_embedding
+
+    def set_input_embeddings(self, value: nn.Module):
+        self.vision_model.embeddings.patch_embedding = value
+
     @check_model_inputs
     @auto_docstring
     def forward(
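These accessors let both wrappers answer `get_input_embeddings()`/`set_input_embeddings()` directly; note that for the vision classifier the "input embedding" is a `Conv2d` patch projection rather than an `nn.Embedding`. A hedged usage sketch (the checkpoint id is an example, not taken from this diff):

from transformers import SiglipModel

model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
emb = model.get_input_embeddings()  # the text tower's token_embedding
print(type(emb).__name__, tuple(emb.weight.shape))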