transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/dots1/modeling_dots1.py

@@ -29,7 +29,12 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -37,7 +42,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
 from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dots1 import Dots1Config

@@ -80,7 +85,7 @@ class Dots1RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
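
Note on the `original_inv_freq` change above: registering the tensor as a non-persistent buffer instead of a plain attribute lets it follow `model.to(...)` device and dtype moves while staying out of the state dict, and the `.clone()` keeps it from aliasing the `inv_freq` buffer that dynamic RoPE variants overwrite at runtime. A minimal standalone demonstration of the difference (not library code):

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = torch.arange(4, dtype=torch.float32)
        self.plain_attr = inv_freq  # invisible to nn.Module: ignored by .to() and state_dict()
        self.register_buffer("buffered", inv_freq.clone(), persistent=False)

m = Demo()
print("buffered" in m.state_dict())  # False: non-persistent buffers are never serialized
# After m.to("cuda"), m.buffered would live on the GPU while m.plain_attr stays on CPU.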
@@ -308,6 +313,7 @@ class Dots1TopkRouter(nn.Module):
         return router_logits


+@use_experts_implementation
 class Dots1NaiveMoe(nn.Module):
     """Collection of expert weights stored as 3D tensors."""

@@ -315,7 +321,7 @@ class Dots1NaiveMoe(nn.Module):
         super().__init__()
         self.num_experts = config.num_local_experts
         self.hidden_dim = config.hidden_size
-        self.intermediate_dim = config.intermediate_size
+        self.intermediate_dim = config.moe_intermediate_size
         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
         self.act_fn = ACT2FN[config.hidden_act]
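
Context for the `moe_intermediate_size` fix above: in MoE checkpoints the per-expert FFN width is `config.moe_intermediate_size`, which generally differs from the dense `intermediate_size`, so the old code allocated wrongly shaped expert tensors. A hedged sketch of how 3D expert weights of this layout can be consumed; top-1 routing and SiLU are simplifications here, not the library's exact forward:

import torch
import torch.nn.functional as F

def naive_moe_forward(hidden, gate_up_proj, down_proj, expert_index):
    # hidden: (num_tokens, hidden_dim)
    # gate_up_proj: (num_experts, 2 * intermediate_dim, hidden_dim)
    # down_proj: (num_experts, hidden_dim, intermediate_dim)
    # expert_index: (num_tokens,) chosen expert per token (top-1 for brevity)
    out = torch.zeros_like(hidden)
    for expert_id in range(gate_up_proj.shape[0]):
        mask = expert_index == expert_id
        if not mask.any():
            continue
        tokens = hidden[mask]                         # (n, hidden_dim)
        gate_up = tokens @ gate_up_proj[expert_id].T  # (n, 2 * intermediate_dim)
        gate, up = gate_up.chunk(2, dim=-1)
        out[mask] = (F.silu(gate) * up) @ down_proj[expert_id].T  # back to (n, hidden_dim)
    return out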
@@ -463,7 +469,9 @@ class Dots1PreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph = False
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": Dots1DecoderLayer,
@@ -476,6 +484,7 @@ class Dots1PreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, Dots1TopkRouter):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            init.zeros_(module.e_score_correction_bias)
         elif isinstance(module, Dots1NaiveMoe):
             init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
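
The `_can_compile_fullgraph` flag is now computed at import time: full-graph `torch.compile` support is only advertised when the grouped-GEMM path used by the experts interface is available. A hedged sketch of the kind of capability probe involved; the real `is_grouped_mm_available` lives in `transformers.utils` and its exact criteria may differ:

import torch

def grouped_mm_available() -> bool:
    # torch._grouped_mm is the grouped-GEMM entry point recent PyTorch builds
    # expose for MoE-style batched expert matmuls; probing for the attribute
    # is the simplest feature test.
    return hasattr(torch, "_grouped_mm")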
transformers/models/dpt/configuration_dpt.py

@@ -102,7 +102,7 @@ class DPTConfig(PreTrainedConfig):
             Used only for the `hybrid` embedding type. The shape of the feature maps of the backbone.
         neck_ignore_stages (`list[int]`, *optional*, defaults to `[0, 1]`):
             Used only for the `hybrid` embedding type. The stages of the readout layers to ignore.
-        backbone_config (`Union[dict[str, Any], PreTrainedConfig]`, *optional*):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `BitConfig()`):
             The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to
             leverage the [`AutoBackbone`] API.
         backbone (`str`, *optional*):
transformers/models/dpt/image_processing_dpt_fast.py

@@ -225,8 +225,7 @@ class DPTImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
-        return BatchFeature(data={"pixel_values": processed_images})
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

     def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
         """

transformers/models/dpt/modular_dpt.py

@@ -228,8 +228,7 @@ class DPTImageProcessorFast(BeitImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
-        return BatchFeature(data={"pixel_values": processed_images})
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

     def post_process_depth_estimation(
         self,
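
Both DPT hunks above replace a manual `torch.stack` with `BatchFeature`'s own tensor handling: passing `tensor_type=return_tensors` converts (and, for equal-shaped torch tensors, stacks) the list when a tensor type is requested, and passes the Python list through when it is `None`. A short usage sketch, assuming images of one shape:

import torch
from transformers import BatchFeature

images = [torch.rand(3, 224, 224) for _ in range(4)]  # processed images of equal shape
batch = BatchFeature(data={"pixel_values": images}, tensor_type="pt")
print(batch["pixel_values"].shape)  # torch.Size([4, 3, 224, 224])

batch = BatchFeature(data={"pixel_values": images}, tensor_type=None)
print(type(batch["pixel_values"]))  # <class 'list'>: left untouched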
transformers/models/edgetam/configuration_edgetam.py

@@ -33,7 +33,7 @@ class EdgeTamVisionConfig(PreTrainedConfig):
    documentation from [`PreTrainedConfig`] for more information.

    Args:
-        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `timm/repvit_m1.dist_in1k`):
            Configuration for the vision backbone. This is used to instantiate the backbone using
            `AutoModel.from_config`.
        backbone_channel_list (`List[int]`, *optional*, defaults to `[384, 192, 96, 48]`):
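
Assuming the default wiring matches the updated docstring, constructing the vision config with no arguments should resolve the backbone to the timm RepViT checkpoint; a sketch, with the import path taken from this diff:

from transformers.models.edgetam.configuration_edgetam import EdgeTamVisionConfig

config = EdgeTamVisionConfig()  # backbone_config defaults to timm/repvit_m1.dist_in1k per the docstring
print(config.backbone_config)   # expected: a config wrapping the timm backbone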
transformers/models/edgetam/modeling_edgetam.py

@@ -30,7 +30,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor

-from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs
+from transformers.utils.generic import OutputRecorder

 from ... import initialization as init
 from ...activations import ACT2FN
@@ -39,6 +39,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import ModelOutput, auto_docstring
+from ...utils.generic import TransformersKwargs, check_model_inputs
 from ..auto import AutoModel
 from .configuration_edgetam import (
     EdgeTamConfig,
@@ -50,7 +51,7 @@ from .configuration_edgetam import (

 # fix this in modular
 if True:
-    from transformers.models.timm_wrapper.modeling_timm_wrapper import TimmWrapperModel
+    from ..timm_wrapper.modeling_timm_wrapper import TimmWrapperModel


 class EdgeTamLayerNorm(nn.LayerNorm):
@@ -315,6 +316,8 @@ class EdgeTamPreTrainedModel(PreTrainedModel):
         if isinstance(module, EdgeTamModel):
             if module.no_memory_embedding is not None:
                 init.zeros_(module.no_memory_embedding)
+        elif hasattr(module, "positional_embedding"):
+            init.normal_(module.positional_embedding, std=module.scale)


 # copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding
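
The new `elif hasattr(module, "positional_embedding")` branch above re-initializes any submodule exposing that buffer without naming its class, which covers the prompt encoder's random positional embedding. A standalone sketch of the duck-typed pattern, using a hypothetical module:

import torch
import torch.nn as nn

class RandomFourierPE(nn.Module):  # hypothetical stand-in, not the library's class
    def __init__(self, hidden_size: int, scale: float = 1.0):
        super().__init__()
        self.scale = scale
        self.register_buffer("positional_embedding", scale * torch.randn(2, hidden_size // 2))

def init_weights(module: nn.Module) -> None:
    # mirrors the duck-typed branch: re-draw the buffer with std = module.scale
    if hasattr(module, "positional_embedding"):
        with torch.no_grad():
            module.positional_embedding.normal_(mean=0.0, std=module.scale)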
transformers/models/edgetam/modular_edgetam.py

@@ -19,8 +19,17 @@ from typing import Optional, Union
 import torch
 import torch.utils.checkpoint

-from transformers.models.sam2.configuration_sam2 import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig
-from transformers.models.sam2.modeling_sam2 import (
+from ... import initialization as init
+from ...configuration_utils import PreTrainedConfig
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+    auto_docstring,
+)
+from ...utils.generic import TransformersKwargs, check_model_inputs
+from ..auto import CONFIG_MAPPING, AutoConfig
+from ..sam2.configuration_sam2 import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig
+from ..sam2.modeling_sam2 import (
     Sam2Attention,
     Sam2FeedForward,
     Sam2LayerNorm,
@@ -30,21 +39,11 @@ from transformers.models.sam2.modeling_sam2 import (
     Sam2VisionEncoderOutput,
     Sam2VisionModel,
 )
-from transformers.utils.generic import TransformersKwargs, check_model_inputs
-
-from ... import initialization as init
-from ...configuration_utils import PreTrainedConfig
-from ...modeling_utils import PreTrainedModel
-from ...processing_utils import Unpack
-from ...utils import (
-    auto_docstring,
-)
-from ..auto import CONFIG_MAPPING, AutoConfig


 # fix this in modular
 if True:
-    from transformers.models.timm_wrapper.modeling_timm_wrapper import TimmWrapperModel
+    from ..timm_wrapper.modeling_timm_wrapper import TimmWrapperModel


 class EdgeTamVisionConfig(PreTrainedConfig):
@@ -58,7 +57,7 @@ class EdgeTamVisionConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.

     Args:
-        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `timm/repvit_m1.dist_in1k`):
            Configuration for the vision backbone. This is used to instantiate the backbone using
            `AutoModel.from_config`.
        backbone_channel_list (`List[int]`, *optional*, defaults to `[384, 192, 96, 48]`):
@@ -181,6 +180,8 @@ class EdgeTamPreTrainedModel(Sam2PreTrainedModel):
         if isinstance(module, EdgeTamModel):
             if module.no_memory_embedding is not None:
                 init.zeros_(module.no_memory_embedding)
+        elif hasattr(module, "positional_embedding"):
+            init.normal_(module.positional_embedding, std=module.scale)


 @auto_docstring(
transformers/models/edgetam_video/modeling_edgetam_video.py

@@ -152,24 +152,17 @@ class EdgeTamVideoVisionRotaryEmbedding(nn.Module):

     def __init__(self, config: EdgeTamVideoConfig, end_x: Optional[int] = None, end_y: Optional[int] = None):
         super().__init__()
-        dim = config.memory_attention_hidden_size // (
+        self.dim = config.memory_attention_hidden_size // (
             config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
         )
         # Ensure even dimension for proper axial splitting
-        if dim % 4 != 0:
+        if self.dim % 4 != 0:
             raise ValueError("Dimension must be divisible by 4 for axial RoPE")
-        end_x, end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
-        freqs = 1.0 / (config.memory_attention_rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+        self.end_x, self.end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
+        self.memory_attention_rope_theta = config.memory_attention_rope_theta

-        # Generate 2D position indices for axial rotary embedding
-        flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
-        x_positions = flattened_indices % end_x
-        y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
-        freqs_x = torch.outer(x_positions, freqs).float()
-        freqs_y = torch.outer(y_positions, freqs).float()
-        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
-        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
         # directly register the cos and sin embeddings as we have a fixed feature shape
+        inv_freq = self.create_inv_freq()
         self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
         self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)

@@ -178,6 +171,20 @@ class EdgeTamVideoVisionRotaryEmbedding(nn.Module):
         # As the feature map size is fixed, we can just return the pre-computed embeddings.
         return self.rope_embeddings_cos, self.rope_embeddings_sin

+    def create_inv_freq(self):
+        freqs = 1.0 / (
+            self.memory_attention_rope_theta ** (torch.arange(0, self.dim, 4)[: (self.dim // 4)].float() / self.dim)
+        )
+        # Generate 2D position indices for axial rotary embedding
+        flattened_indices = torch.arange(self.end_x * self.end_y, dtype=torch.long)
+        x_positions = flattened_indices % self.end_x
+        y_positions = torch.div(flattened_indices, self.end_x, rounding_mode="floor")
+        freqs_x = torch.outer(x_positions, freqs).float()
+        freqs_y = torch.outer(y_positions, freqs).float()
+        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+        return inv_freq
+

 def eager_attention_forward(
     module: nn.Module,
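
A worked example of the frequency table `create_inv_freq` builds, with toy sizes (dim=8 on a 3x2 feature grid) so the shapes are easy to follow; this mirrors the method above outside the class:

import torch

dim, end_x, end_y, theta = 8, 3, 2, 10000.0
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: dim // 4].float() / dim))  # (dim // 4,) = (2,)

idx = torch.arange(end_x * end_y)                     # one index per grid position
x_pos = idx % end_x                                   # column of each position
y_pos = torch.div(idx, end_x, rounding_mode="floor")  # row of each position
freqs_x = torch.outer(x_pos, freqs).float()           # (6, 2): x angle per position and frequency
freqs_y = torch.outer(y_pos, freqs).float()           # (6, 2): y angle per position and frequency
inv_freq = torch.cat([freqs_x, freqs_y], dim=-1).repeat_interleave(2, dim=-1)
print(inv_freq.shape)  # torch.Size([6, 8]): half the channels rotate with x, half with y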
@@ -769,6 +776,31 @@ class EdgeTamVideoFeedForward(nn.Module):
         return hidden_states


+class EdgeTamVideoPositionalEmbedding(nn.Module):
+    def __init__(self, config: EdgeTamVideoPromptEncoderConfig):
+        super().__init__()
+        self.scale = config.scale
+        positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
+        self.register_buffer("positional_embedding", positional_embedding)
+
+    def forward(self, input_coords, input_shape=None):
+        """Positionally encode points that are normalized to [0,1]."""
+        coordinates = input_coords.clone()
+
+        if input_shape is not None:
+            coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
+            coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
+        coordinates.to(torch.float32)
+
+        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+        coordinates = 2 * coordinates - 1
+        coordinates = coordinates.to(self.positional_embedding.dtype)
+        coordinates = coordinates @ self.positional_embedding
+        coordinates = 2 * np.pi * coordinates
+        # outputs d_1 x ... x d_n x channel shape
+        return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
+
+
 @auto_docstring
 class EdgeTamVideoPreTrainedModel(PreTrainedModel):
     config_class = EdgeTamVideoConfig
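The class added above is a standard random-Fourier-feature point encoding, moved up in the file (the unassigned `coordinates.to(torch.float32)` is a no-op carried over verbatim from the copy deleted further down). A minimal sketch of the encoding with illustrative sizes:

```python
# Hedged sketch of the random-Fourier-feature encoding: points in [0, 1]^2 are
# mapped to [-1, 1], projected through a random 2 x (C/2) Gaussian matrix, and
# expanded with sin/cos. hidden_size and scale here are illustrative.
import numpy as np
import torch

hidden_size = 8
scale = 1.0
positional_embedding = scale * torch.randn((2, hidden_size // 2))

points = torch.rand(1, 1, 3, 2)          # d_1 x ... x d_n x 2 coords in [0, 1]
coords = 2 * points - 1                  # map to [-1, 1]
coords = coords @ positional_embedding   # project to hidden_size // 2 channels
coords = 2 * np.pi * coords
encoded = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
print(encoded.shape)                     # torch.Size([1, 1, 3, 8])
```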
@@ -794,6 +826,16 @@ class EdgeTamVideoPreTrainedModel(PreTrainedModel):
         if isinstance(module, EdgeTamVideoMemoryFuserCXBlock):
             if module.scale is not None:
                 init.zeros_(module.scale)
+        elif isinstance(module, EdgeTamVideoVisionRotaryEmbedding):
+            inv_freq = module.create_inv_freq()
+            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+            init.copy_(module.rope_embeddings_sin, inv_freq.sin())
+        elif isinstance(module, EdgeTamVideoPositionalEmbedding):
+            init.normal_(module.positional_embedding, std=module.scale)
+        if isinstance(module, EdgeTamVideoVisionRotaryEmbedding):
+            inv_freq = module.create_inv_freq()
+            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+            init.copy_(module.rope_embeddings_sin, inv_freq.sin())


 class EdgeTamVideoInferenceCache:
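The branches above rebuild buffers that never reach the checkpoint (note the trailing `if` repeats the rotary-embedding `elif` verbatim; it is redundant but harmless). Non-persistent buffers are excluded from `state_dict`, so after a meta-device load they must be recomputed. A minimal sketch of the mechanism, with `Tensor.copy_` standing in for `init.copy_`:

```python
# Minimal sketch: non-persistent buffers are excluded from the state dict, so
# they cannot be restored from a checkpoint and have to be recomputed after
# loading, which is what the _init_weights branches above do.
import torch
from torch import nn

class WithTable(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("table", torch.arange(4).float(), persistent=False)

module = WithTable()
print("table" in module.state_dict())  # False -- nothing to load from disk
with torch.no_grad():
    module.table.copy_(torch.arange(4).float())  # stands in for init.copy_
```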
@@ -959,7 +1001,7 @@ class EdgeTamVideoInferenceSession:
         device_inputs = {}
         for key, value in inputs.items():
             if isinstance(value, torch.Tensor):
-                device_inputs[key] = value.to(self.inference_device, non_blocking=True)
+                device_inputs[key] = value.to(self.inference_device, non_blocking=False)
             else:
                 device_inputs[key] = value
         self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs
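`non_blocking=True` is only safe for host-to-device copies out of pinned memory; in other directions the copy can be observed before it finishes, which is presumably the race this hunk sidesteps. A hedged illustration, guarded so it also runs on CPU-only machines:

```python
# Hedged illustration of the non_blocking pitfall: an async device-to-host
# copy may be read before it completes unless the caller synchronizes.
import torch

if torch.cuda.is_available():
    device_tensor = torch.randn(1024, device="cuda")
    host_tensor = device_tensor.to("cpu", non_blocking=True)
    torch.cuda.synchronize()  # without this, host_tensor may hold stale bytes
    print(torch.equal(host_tensor, device_tensor.cpu()))
else:
    print("CUDA not available; non_blocking has no effect on CPU-to-CPU copies")
```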
@@ -1547,31 +1589,6 @@ class EdgeTamVideoSegmentationOutput(ModelOutput):
     frame_idx: Optional[int] = None


-class EdgeTamVideoPositionalEmbedding(nn.Module):
-    def __init__(self, config: EdgeTamVideoPromptEncoderConfig):
-        super().__init__()
-        self.scale = config.scale
-        positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
-        self.register_buffer("positional_embedding", positional_embedding)
-
-    def forward(self, input_coords, input_shape=None):
-        """Positionally encode points that are normalized to [0,1]."""
-        coordinates = input_coords.clone()
-
-        if input_shape is not None:
-            coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
-            coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
-        coordinates.to(torch.float32)
-
-        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
-        coordinates = 2 * coordinates - 1
-        coordinates = coordinates.to(self.positional_embedding.dtype)
-        coordinates = coordinates @ self.positional_embedding
-        coordinates = 2 * np.pi * coordinates
-        # outputs d_1 x ... x d_n x channel shape
-        return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
-
-
 class EdgeTamVideoMaskEmbedding(nn.Module):
     def __init__(self, config: EdgeTamVideoPromptEncoderConfig):
         super().__init__()
@@ -1976,11 +1993,6 @@ class EdgeTamVideoModel(EdgeTamVideoPreTrainedModel):
     input_modalities = ("video", "text")
     _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamVideoTwoWayAttentionBlock, index=2)}
     _keys_to_ignore_on_load_unexpected = []
-    _tied_weights_keys = {
-        "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding"
-    }
-    # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
-    _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]

     def __init__(self, config: EdgeTamVideoConfig):
         super().__init__(config)
@@ -29,6 +29,7 @@ from transformers.models.sam2.modeling_sam2 import (
 )
 from transformers.utils.generic import OutputRecorder

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...configuration_utils import PreTrainedConfig
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -375,24 +376,17 @@ class EdgeTamVideoVisionEncoderOutput(Sam2VideoVisionEncoderOutput):
 class EdgeTamVideoVisionRotaryEmbedding(Sam2VideoVisionRotaryEmbedding):
     def __init__(self, config: EdgeTamVideoConfig, end_x: Optional[int] = None, end_y: Optional[int] = None):
         nn.Module.__init__(self)
-        dim = config.memory_attention_hidden_size // (
+        self.dim = config.memory_attention_hidden_size // (
             config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
         )
         # Ensure even dimension for proper axial splitting
-        if dim % 4 != 0:
+        if self.dim % 4 != 0:
             raise ValueError("Dimension must be divisible by 4 for axial RoPE")
-        end_x, end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
-        freqs = 1.0 / (config.memory_attention_rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
-
-        # Generate 2D position indices for axial rotary embedding
-        flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
-        x_positions = flattened_indices % end_x
-        y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
-        freqs_x = torch.outer(x_positions, freqs).float()
-        freqs_y = torch.outer(y_positions, freqs).float()
-        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
-        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+        self.end_x, self.end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
+        self.memory_attention_rope_theta = config.memory_attention_rope_theta
+
         # directly register the cos and sin embeddings as we have a fixed feature shape
+        inv_freq = self.create_inv_freq()
         self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
         self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)
@@ -662,7 +656,12 @@ class EdgeTamVideoFeedForward(Sam2VideoFeedForward):


 class EdgeTamVideoPreTrainedModel(Sam2VideoPreTrainedModel):
-    pass
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, EdgeTamVideoVisionRotaryEmbedding):
+            inv_freq = module.create_inv_freq()
+            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+            init.copy_(module.rope_embeddings_sin, inv_freq.sin())


 class EdgeTamVideoInferenceSession(Sam2VideoInferenceSession):
@@ -1040,11 +1039,6 @@ class EdgeTamVideoSegmentationOutput(Sam2VideoSegmentationOutput):

 @auto_docstring
 class EdgeTamVideoModel(Sam2VideoModel):
-    _tied_weights_keys = {
-        "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding"
-    }
-    # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
-    _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
     _keys_to_ignore_on_load_unexpected = []
     _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamVideoTwoWayAttentionBlock, index=2)}
@@ -153,9 +153,8 @@ class EfficientLoFTRImageProcessorFast(BaseImageProcessorFast):
         stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs]

         # Return in same format as slow processor
-        image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs

-        return BatchFeature(data={"pixel_values": image_pairs})
+        return BatchFeature(data={"pixel_values": stacked_pairs}, tensor_type=return_tensors)

     def post_process_keypoint_matching(
         self,
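The fix delegates tensor conversion to `BatchFeature`'s `tensor_type` argument instead of stacking by hand, so `return_tensors=None` keeps a plain list (the old code also dropped `return_tensors` entirely). A sketch under the assumption that the 5.x `BatchFeature` stacks a list of equal-shaped tensors when `tensor_type="pt"`, which is what this hunk relies on:

```python
# Hedged sketch of the BatchFeature contract the fix depends on (assumed 5.x
# behavior): tensor_type="pt" converts the list of per-pair tensors into one
# stacked tensor, while tensor_type=None leaves the Python list untouched.
import torch
from transformers.feature_extraction_utils import BatchFeature

pairs = [torch.zeros(2, 3, 32, 32), torch.zeros(2, 3, 32, 32)]
as_tensors = BatchFeature(data={"pixel_values": pairs}, tensor_type="pt")
as_list = BatchFeature(data={"pixel_values": pairs}, tensor_type=None)
print(as_tensors["pixel_values"].shape)  # expected: torch.Size([2, 2, 3, 32, 32])
print(type(as_list["pixel_values"]))     # <class 'list'>
```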
@@ -103,7 +103,7 @@ class EfficientLoFTRRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     # Ignore copy
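Registering `original_inv_freq` as a non-persistent buffer (and cloning it) means it follows `Module.to()` like `inv_freq` does, instead of being a plain attribute stranded on the construction device and aliasing the same storage. A small sketch of the difference:

```python
# Sketch of why a registered buffer beats a plain tensor attribute: Module.to()
# (here a dtype conversion, same story for devices) converts buffers but never
# sees plain attributes, which also keep aliasing the original storage.
import torch
from torch import nn

class Rope(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = torch.ones(4)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.plain_attr = inv_freq  # would NOT follow .to(...)

rope = Rope().to(torch.float64)
print(rope.inv_freq.dtype)    # torch.float64 -- buffers are converted
print(rope.plain_attr.dtype)  # torch.float32 -- plain attributes are not
```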
@@ -684,9 +684,22 @@ class EfficientLoFTRPreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.LayerNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+        elif isinstance(module, EfficientLoFTRRotaryEmbedding):
+            rope_fn = (
+                ROPE_INIT_FUNCTIONS[module.rope_type]
+                if module.rope_type != "default"
+                else module.compute_default_rope_parameters
+            )
+            buffer_value, _ = rope_fn(module.config)
+            init.copy_(module.inv_freq, buffer_value)
+            init.copy_(module.original_inv_freq, buffer_value)

     # Copied from transformers.models.superpoint.modeling_superpoint.SuperPointPreTrainedModel.extract_one_channel_pixel_values with SuperPoint->EfficientLoFTR
     def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
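The `getattr` guard distinguishes norm layers that track running statistics (BatchNorm) from those that do not (GroupNorm), and the three resets mirror what PyTorch's own `reset_running_stats` does; the same pattern recurs in the EfficientNet and Emu3 hunks below. A small sketch:

```python
# Sketch of the reset performed by the added lines, mirrored against a plain
# BatchNorm layer; GroupNorm is skipped by the getattr guard because it has
# no running statistics at all.
import torch
from torch import nn

bn = nn.BatchNorm2d(8)
bn.running_mean.fill_(3.0)  # pretend stale stats survived a meta-device load

if getattr(bn, "running_mean", None) is not None:
    with torch.no_grad():
        bn.running_mean.zero_()
        bn.running_var.fill_(1.0)
        bn.num_batches_tracked.zero_()

gn = nn.GroupNorm(2, 8)
print(getattr(gn, "running_mean", None))  # None -- the guard skips GroupNorm
```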
@@ -66,7 +66,7 @@ class EfficientNetImageProcessor(BaseImageProcessor):
         `do_resize` in `preprocess`.
     size (`dict[str, int]`, *optional*, defaults to `{"height": 346, "width": 346}`):
         Size of the image after `resize`. Can be overridden by `size` in `preprocess`.
-    resample (`PILImageResampling` filter, *optional*, defaults to 0):
+    resample (`PILImageResampling` filter, *optional*, defaults to `Resampling.BICUBIC`):
         Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
     do_center_crop (`bool`, *optional*, defaults to `False`):
         Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
@@ -102,7 +102,7 @@ class EfficientNetImageProcessor(BaseImageProcessor):
         self,
         do_resize: bool = True,
         size: Optional[dict[str, int]] = None,
-        resample: PILImageResampling = PIL.Image.NEAREST,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
         do_center_crop: bool = False,
         crop_size: Optional[dict[str, int]] = None,
         rescale_factor: Union[int, float] = 1 / 255,
@@ -133,12 +133,11 @@ class EfficientNetImageProcessor(BaseImageProcessor):
         self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
         self.include_top = include_top

-    # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.NEAREST
     def resize(
         self,
         image: np.ndarray,
         size: dict[str, int],
-        resample: PILImageResampling = PILImageResampling.NEAREST,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
@@ -151,8 +150,8 @@ class EfficientNetImageProcessor(BaseImageProcessor):
                 Image to resize.
             size (`dict[str, int]`):
                 Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
-            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.NEAREST`):
-                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.NEAREST`.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
             data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format for the output image. If unset, the channel dimension format of the input
                 image is used. Can be one of:
@@ -33,7 +33,7 @@ from .image_processing_efficientnet import EfficientNetImageProcessorKwargs

 @auto_docstring
 class EfficientNetImageProcessorFast(BaseImageProcessorFast):
-    resample = PILImageResampling.NEAREST
+    resample = PILImageResampling.BICUBIC
     image_mean = IMAGENET_STANDARD_MEAN
     image_std = IMAGENET_STANDARD_STD
     size = {"height": 346, "width": 346}
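The old docstring's "defaults to 0" was PIL's integer value for NEAREST; the hunks above align the slow and fast EfficientNet processors on BICUBIC. A small check of the enum values involved:

```python
# PILImageResampling aliases PIL's resampling enum, so the integer values
# explain the old docstring: 0 is NEAREST, 3 is the corrected BICUBIC default.
from PIL import Image

print(int(Image.Resampling.NEAREST))  # 0 -- what "defaults to 0" meant
print(int(Image.Resampling.BICUBIC))  # 3 -- the corrected default
```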
@@ -178,7 +178,6 @@ class EfficientNetImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
@@ -435,7 +435,7 @@ class EfficientNetPreTrainedModel(PreTrainedModel):
     base_model_prefix = "efficientnet"
     main_input_name = "pixel_values"
     input_modalities = ("image",)
-    _no_split_modules = []
+    _no_split_modules = ["EfficientNetBlock"]

     @torch.no_grad()
     def _init_weights(self, module: nn.Module):
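Naming `EfficientNetBlock` in `_no_split_modules` tells big-model dispatch not to shard a block (and its residual connection) across devices when `device_map="auto"` is used. A hedged usage sketch; the checkpoint name is illustrative and `accelerate` must be installed:

```python
# Hedged usage sketch: with "EfficientNetBlock" in _no_split_modules, each
# block lands on a single device during automatic dispatch. The checkpoint
# below is an example, not something this diff prescribes.
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained(
    "google/efficientnet-b0",  # example checkpoint
    device_map="auto",         # accelerate respects _no_split_modules here
)
print(getattr(model, "hf_device_map", "single device"))
```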
@@ -444,6 +444,10 @@ class EfficientNetPreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)


 @auto_docstring
@@ -22,6 +22,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ... import initialization as init
 from ...activations import ACT2FN, get_activation
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -532,6 +533,12 @@ class ElectraPreTrainedModel(PreTrainedModel):
         "cross_attentions": ElectraCrossAttention,
     }

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, ElectraEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
+

 @dataclass
 @auto_docstring(
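The override recreates Electra's embedding buffers after loading; both are deterministic, so nothing is lost by recomputing them. A sketch of the values being restored, assuming the usual buffer shapes:

```python
# Sketch of what the new ElectraPreTrainedModel._init_weights restores: the
# position_ids buffer is just 0..max_position_embeddings-1 (shape [1, N]) and
# the token_type_ids buffer is all zeros of the same shape.
import torch

max_position_embeddings = 512  # illustrative value
position_ids = torch.arange(max_position_embeddings).expand((1, -1))
token_type_ids = torch.zeros_like(position_ids)
print(position_ids.shape, position_ids[0, :3])  # torch.Size([1, 512]) tensor([0, 1, 2])
```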
@@ -958,6 +958,10 @@ class Emu3VQVAE(PreTrainedModel):
         elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
             init.constant_(module.weight, 1.0)
             init.constant_(module.bias, 0.0)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.Embedding):
             init.normal_(module.weight)
             # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
@@ -1128,7 +1132,7 @@ class Emu3RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -1615,6 +1619,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
         position_ids=None,
         use_cache=True,
         pixel_values=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1628,10 +1633,11 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
             position_ids=position_ids,
             pixel_values=pixel_values,
             use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

-        if cache_position[0] != 0:
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None

         return model_inputs
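The old `cache_position[0] != 0` test inferred "past the first step" from cache state, which misfires when a cache is pre-filled or absent; the explicit `is_first_iteration` flag makes the intent direct: `pixel_values` are forwarded only on the prefill step, after which the image is already encoded into the KV cache. A minimal sketch of the new gating (the helper name is illustrative):

```python
# Hypothetical helper mirroring the new condition: drop pixel_values on every
# step after the first, but only when a cache actually carries the image.
def should_drop_pixel_values(is_first_iteration: bool, use_cache: bool) -> bool:
    return not is_first_iteration and use_cache

print(should_drop_pixel_values(is_first_iteration=True, use_cache=True))    # False
print(should_drop_pixel_values(is_first_iteration=False, use_cache=True))   # True
print(should_drop_pixel_values(is_first_iteration=False, use_cache=False))  # False
```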
@@ -706,6 +706,10 @@ class Emu3VQVAE(PreTrainedModel):
         elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
             init.constant_(module.weight, 1.0)
             init.constant_(module.bias, 0.0)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.Embedding):
             init.normal_(module.weight)
             # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
@@ -1167,6 +1171,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
         position_ids=None,
         use_cache=True,
         pixel_values=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1180,10 +1185,11 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
             position_ids=position_ids,
             pixel_values=pixel_values,
             use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

-        if cache_position[0] != 0:
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None

         return model_inputs