transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671) hide show
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -215,6 +215,46 @@ class Siglip2VisionEmbeddings(nn.Module):
215
215
  return embeddings
216
216
 
217
217
 
218
+ class Siglip2TextEmbeddings(nn.Module):
219
+ def __init__(self, config: Siglip2TextConfig):
220
+ super().__init__()
221
+ embed_dim = config.hidden_size
222
+
223
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
224
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
225
+
226
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
227
+ self.register_buffer(
228
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
229
+ )
230
+
231
+ def forward(
232
+ self,
233
+ input_ids: Optional[torch.LongTensor] = None,
234
+ position_ids: Optional[torch.LongTensor] = None,
235
+ inputs_embeds: Optional[torch.FloatTensor] = None,
236
+ ) -> torch.Tensor:
237
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
238
+ max_position_embedding = self.position_embedding.weight.shape[0]
239
+
240
+ if seq_length > max_position_embedding:
241
+ raise ValueError(
242
+ f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
243
+ f"{seq_length} and max_position_embeddings: {max_position_embedding}"
244
+ )
245
+
246
+ if position_ids is None:
247
+ position_ids = self.position_ids[:, :seq_length]
248
+
249
+ if inputs_embeds is None:
250
+ inputs_embeds = self.token_embedding(input_ids)
251
+
252
+ position_embeddings = self.position_embedding(position_ids)
253
+ embeddings = inputs_embeds + position_embeddings
254
+
255
+ return embeddings
256
+
257
+
218
258
  def eager_attention_forward(
219
259
  module: nn.Module,
220
260
  query: torch.Tensor,
@@ -412,6 +452,8 @@ class Siglip2PreTrainedModel(PreTrainedModel):
412
452
  else self.config.hidden_size
413
453
  )
414
454
  init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
455
+ if hasattr(module, "position_ids"):
456
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
415
457
  elif isinstance(module, nn.Embedding):
416
458
  default_flax_embed_init(module.weight)
417
459
  elif isinstance(module, Siglip2Attention):
@@ -447,6 +489,8 @@ class Siglip2PreTrainedModel(PreTrainedModel):
447
489
  elif isinstance(module, nn.LayerNorm):
448
490
  init.zeros_(module.bias)
449
491
  init.ones_(module.weight)
492
+ elif isinstance(module, Siglip2TextEmbeddings):
493
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
450
494
 
451
495
 
452
496
  class Siglip2Encoder(nn.Module):
@@ -484,6 +528,7 @@ class Siglip2Encoder(nn.Module):
484
528
 
485
529
 
486
530
  class Siglip2VisionTransformer(Siglip2PreTrainedModel):
531
+ _input_embed_layer = "patch_embedding"
487
532
  _can_record_outputs = {
488
533
  "hidden_states": Siglip2EncoderLayer,
489
534
  "attentions": Siglip2Attention,
@@ -501,6 +546,8 @@ class Siglip2VisionTransformer(Siglip2PreTrainedModel):
501
546
  if self.use_head:
502
547
  self.head = Siglip2MultiheadAttentionPoolingHead(config)
503
548
 
549
+ self.post_init()
550
+
504
551
  @check_model_inputs(tie_last_hidden_states=False)
505
552
  @auto_docstring
506
553
  def forward(
@@ -549,49 +596,11 @@ class Siglip2VisionTransformer(Siglip2PreTrainedModel):
549
596
  )
550
597
 
551
598
 
552
- class Siglip2TextEmbeddings(nn.Module):
553
- def __init__(self, config: Siglip2TextConfig):
554
- super().__init__()
555
- embed_dim = config.hidden_size
599
+ class Siglip2TextTransformer(Siglip2PreTrainedModel):
600
+ _input_embed_layer = "token_embedding"
556
601
 
557
- self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
558
- self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
559
-
560
- # position_ids (1, len position emb) is contiguous in memory and exported when serialized
561
- self.register_buffer(
562
- "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
563
- )
564
-
565
- def forward(
566
- self,
567
- input_ids: Optional[torch.LongTensor] = None,
568
- position_ids: Optional[torch.LongTensor] = None,
569
- inputs_embeds: Optional[torch.FloatTensor] = None,
570
- ) -> torch.Tensor:
571
- seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
572
- max_position_embedding = self.position_embedding.weight.shape[0]
573
-
574
- if seq_length > max_position_embedding:
575
- raise ValueError(
576
- f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
577
- f"{seq_length} and max_position_embeddings: {max_position_embedding}"
578
- )
579
-
580
- if position_ids is None:
581
- position_ids = self.position_ids[:, :seq_length]
582
-
583
- if inputs_embeds is None:
584
- inputs_embeds = self.token_embedding(input_ids)
585
-
586
- position_embeddings = self.position_embedding(position_ids)
587
- embeddings = inputs_embeds + position_embeddings
588
-
589
- return embeddings
590
-
591
-
592
- class Siglip2TextTransformer(nn.Module):
593
602
  def __init__(self, config: Siglip2TextConfig):
594
- super().__init__()
603
+ super().__init__(config)
595
604
  self.config = config
596
605
  embed_dim = config.hidden_size
597
606
  self.embeddings = Siglip2TextEmbeddings(config)
@@ -599,6 +608,7 @@ class Siglip2TextTransformer(nn.Module):
599
608
  self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
600
609
 
601
610
  self.head = nn.Linear(embed_dim, config.projection_size)
611
+ self.post_init()
602
612
 
603
613
  @can_return_tuple
604
614
  @auto_docstring
@@ -833,6 +843,12 @@ class Siglip2Model(Siglip2PreTrainedModel):
833
843
  # Initialize weights and apply final processing
834
844
  self.post_init()
835
845
 
846
+ def get_input_embeddings(self) -> nn.Module:
847
+ return self.text_model.embeddings.token_embedding
848
+
849
+ def set_input_embeddings(self, value: nn.Module):
850
+ self.text_model.embeddings.token_embedding = value
851
+
836
852
  @filter_out_non_signature_kwargs()
837
853
  @auto_docstring
838
854
  def get_text_features(
@@ -1051,6 +1067,12 @@ class Siglip2ForImageClassification(Siglip2PreTrainedModel):
1051
1067
  # Initialize weights and apply final processing
1052
1068
  self.post_init()
1053
1069
 
1070
+ def get_input_embeddings(self) -> nn.Module:
1071
+ return self.vision_model.embeddings.patch_embedding
1072
+
1073
+ def set_input_embeddings(self, value: nn.Module):
1074
+ self.vision_model.embeddings.patch_embedding = value
1075
+
1054
1076
  @check_model_inputs
1055
1077
  @auto_docstring
1056
1078
  def forward(
@@ -63,7 +63,7 @@ class SmolLM3RotaryEmbedding(nn.Module):
63
63
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
64
64
 
65
65
  self.register_buffer("inv_freq", inv_freq, persistent=False)
66
- self.original_inv_freq = inv_freq
66
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
67
67
 
68
68
  @staticmethod
69
69
  def compute_default_rope_parameters(
@@ -330,6 +330,8 @@ class SmolVLMVisionTransformer(SmolVLMPreTrainedModel):
330
330
  self.patch_size = config.patch_size
331
331
  self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
332
332
 
333
+ self.post_init()
334
+
333
335
  def get_input_embeddings(self):
334
336
  return self.embeddings
335
337
 
@@ -853,6 +855,7 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
853
855
  pixel_attention_mask=None,
854
856
  image_hidden_states=None,
855
857
  logits_to_keep=None,
858
+ is_first_iteration=False,
856
859
  **kwargs,
857
860
  ):
858
861
  # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
@@ -868,10 +871,11 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
868
871
  pixel_attention_mask=pixel_attention_mask,
869
872
  image_hidden_states=image_hidden_states,
870
873
  logits_to_keep=logits_to_keep,
874
+ is_first_iteration=is_first_iteration,
871
875
  **kwargs,
872
876
  )
873
877
 
874
- if image_hidden_states is not None or cache_position[0] != 0:
878
+ if image_hidden_states is not None or not is_first_iteration:
875
879
  model_inputs["pixel_values"] = None
876
880
  model_inputs["pixel_attention_mask"] = None
877
881
 
@@ -331,7 +331,6 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
331
331
  processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
332
332
  pixel_attention_mask = reorder_videos(processed_padded_mask_grouped, grouped_videos_index)
333
333
 
334
- processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos
335
334
  data = {"pixel_values": processed_videos}
336
335
 
337
336
  if do_pad:
@@ -22,6 +22,7 @@ import torch
22
22
  from torch import nn
23
23
  from torch.nn import CrossEntropyLoss
24
24
 
25
+ from ... import initialization as init
25
26
  from ...activations import ACT2FN
26
27
  from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
27
28
  from ...generation import GenerationMixin
@@ -105,6 +106,7 @@ class Speech2TextSinusoidalPositionalEmbedding(nn.Module):
105
106
  def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
106
107
  super().__init__()
107
108
  self.offset = 2
109
+ self.num_positions = num_positions
108
110
  self.embedding_dim = embedding_dim
109
111
  self.padding_idx = padding_idx
110
112
  self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
@@ -495,6 +497,14 @@ class Speech2TextPreTrainedModel(PreTrainedModel):
495
497
  _supports_sdpa = False
496
498
  _supports_flex_attn = False
497
499
 
500
+ def _init_weights(self, module):
501
+ super()._init_weights(module)
502
+ if isinstance(module, Speech2TextSinusoidalPositionalEmbedding):
503
+ emb_weights = module.get_embedding(
504
+ module.num_positions + module.offset, module.embedding_dim, module.padding_idx
505
+ )
506
+ init.copy_(module.weights, emb_weights)
507
+
498
508
  def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
499
509
  """
500
510
  Computes the output length of the convolutional layers
@@ -290,6 +290,7 @@ class SpeechT5SinusoidalPositionalEmbedding(nn.Module):
290
290
  def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
291
291
  super().__init__()
292
292
  self.offset = 2
293
+ self.num_positions = num_positions
293
294
  self.embedding_dim = embedding_dim
294
295
  self.padding_idx = padding_idx
295
296
  self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
@@ -414,6 +415,7 @@ class SpeechT5ScaledPositionalEncoding(nn.Module):
414
415
  self.register_buffer("pe", pe, persistent=False)
415
416
  self.dropout = nn.Dropout(p=dropout)
416
417
  self.dim = dim
418
+ self.max_len = max_len
417
419
  self.alpha = nn.Parameter(torch.tensor(1.0))
418
420
 
419
421
  def forward(self, emb):
@@ -1184,6 +1186,14 @@ class SpeechT5PreTrainedModel(PreTrainedModel):
1184
1186
  init.constant_(module.conv.bias, 0)
1185
1187
  elif isinstance(module, SpeechT5ScaledPositionalEncoding):
1186
1188
  init.ones_(module.alpha)
1189
+ dim, max_len = module.dim, module.max_len
1190
+ pe = torch.zeros(max_len, dim)
1191
+ position = torch.arange(0, max_len).unsqueeze(1)
1192
+ div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.int64).float() * -(math.log(10000.0) / dim))
1193
+ pe[:, 0::2] = torch.sin(position.float() * div_term)
1194
+ pe[:, 1::2] = torch.cos(position.float() * div_term)
1195
+ pe = pe.unsqueeze(0)
1196
+ init.copy_(module.pe, pe)
1187
1197
  elif isinstance(module, SpeechT5FeatureProjection):
1188
1198
  k = math.sqrt(1 / module.projection.in_features)
1189
1199
  init.uniform_(module.projection.weight, a=-k, b=k)
@@ -1195,6 +1205,10 @@ class SpeechT5PreTrainedModel(PreTrainedModel):
1195
1205
  elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)):
1196
1206
  init.zeros_(module.bias)
1197
1207
  init.ones_(module.weight)
1208
+ if getattr(module, "running_mean", None) is not None:
1209
+ init.zeros_(module.running_mean)
1210
+ init.ones_(module.running_var)
1211
+ init.zeros_(module.num_batches_tracked)
1198
1212
  elif isinstance(module, nn.Conv1d):
1199
1213
  init.kaiming_normal_(module.weight)
1200
1214
  if module.bias is not None:
@@ -1205,6 +1219,14 @@ class SpeechT5PreTrainedModel(PreTrainedModel):
1205
1219
  # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
1206
1220
  if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
1207
1221
  init.zeros_(module.weight[module.padding_idx])
1222
+ elif isinstance(module, SpeechT5SinusoidalPositionalEmbedding):
1223
+ emb_weights = module.get_embedding(
1224
+ module.num_positions + module.offset, module.embedding_dim, module.padding_idx
1225
+ )
1226
+ init.copy_(module.weights, emb_weights)
1227
+ elif isinstance(module, SpeechT5HifiGan):
1228
+ init.zeros_(module.mean)
1229
+ init.ones_(module.scale)
1208
1230
 
1209
1231
  if hasattr(module, "masked_spec_embed"):
1210
1232
  init.uniform_(module.masked_spec_embed)
@@ -3008,6 +3030,12 @@ class SpeechT5HifiGan(PreTrainedModel):
3008
3030
  # Initialize weights and apply final processing
3009
3031
  self.post_init()
3010
3032
 
3033
+ def _init_weights(self, module):
3034
+ super()._init_weights(module)
3035
+ if isinstance(module, SpeechT5HifiGan):
3036
+ init.zeros_(module.mean)
3037
+ init.ones_(module.scale)
3038
+
3011
3039
  def apply_weight_norm(self):
3012
3040
  weight_norm = nn.utils.weight_norm
3013
3041
  if hasattr(nn.utils.parametrizations, "weight_norm"):
@@ -22,6 +22,7 @@ import torch
22
22
  from torch import nn
23
23
  from torch.nn import CrossEntropyLoss
24
24
 
25
+ from ... import initialization as init
25
26
  from ...activations import ACT2FN
26
27
  from ...modeling_layers import GradientCheckpointingLayer
27
28
  from ...modeling_outputs import BaseModelOutput, ModelOutput, QuestionAnsweringModelOutput
@@ -305,9 +306,9 @@ class SplinterEncoder(nn.Module):
305
306
  all_hidden_states = all_hidden_states + (hidden_states,)
306
307
 
307
308
  layer_outputs = layer_module(
308
- hidden_states=hidden_states,
309
- attention_mask=attention_mask,
310
- output_attentions=output_attentions,
309
+ hidden_states,
310
+ attention_mask,
311
+ output_attentions,
311
312
  **kwargs,
312
313
  )
313
314
 
@@ -331,6 +332,11 @@ class SplinterPreTrainedModel(PreTrainedModel):
331
332
  base_model_prefix = "splinter"
332
333
  supports_gradient_checkpointing = True
333
334
 
335
+ def _init_weights(self, module):
336
+ super()._init_weights(module)
337
+ if isinstance(module, SplinterEmbeddings):
338
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
339
+
334
340
 
335
341
  @auto_docstring
336
342
  class SplinterModel(SplinterPreTrainedModel):
@@ -412,6 +412,8 @@ class SqueezeBertPreTrainedModel(PreTrainedModel):
412
412
  super()._init_weights(module)
413
413
  if isinstance(module, SqueezeBertLMPredictionHead):
414
414
  init.zeros_(module.bias)
415
+ elif isinstance(module, SqueezeBertEmbeddings):
416
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
415
417
 
416
418
 
417
419
  @auto_docstring
@@ -76,7 +76,7 @@ class StableLmRotaryEmbedding(nn.Module):
76
76
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
77
77
 
78
78
  self.register_buffer("inv_freq", inv_freq, persistent=False)
79
- self.original_inv_freq = inv_freq
79
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
80
80
 
81
81
  @staticmethod
82
82
  # Ignore copy
@@ -289,7 +289,7 @@ class Starcoder2RotaryEmbedding(nn.Module):
289
289
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
290
290
 
291
291
  self.register_buffer("inv_freq", inv_freq, persistent=False)
292
- self.original_inv_freq = inv_freq
292
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
293
293
 
294
294
  @staticmethod
295
295
  def compute_default_rope_parameters(
@@ -161,9 +161,8 @@ class SuperGlueImageProcessorFast(BaseImageProcessorFast):
161
161
  stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs]
162
162
 
163
163
  # Return in same format as slow processor
164
- image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs
165
164
 
166
- return BatchFeature(data={"pixel_values": image_pairs})
165
+ return BatchFeature(data={"pixel_values": stacked_pairs}, tensor_type=return_tensors)
167
166
 
168
167
  def post_process_keypoint_matching(
169
168
  self,
@@ -110,8 +110,7 @@ class SuperPointImageProcessorFast(BaseImageProcessorFast):
110
110
  stacked_images = self.rescale(stacked_images, rescale_factor)
111
111
  processed_images_grouped[shape] = stacked_images
112
112
  processed_images = reorder_images(processed_images_grouped, grouped_images_index)
113
- processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
114
- return BatchFeature(data={"pixel_values": processed_images})
113
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
115
114
 
116
115
  def post_process_keypoint_detection(
117
116
  self, outputs: "SuperPointKeypointDescriptionOutput", target_sizes: Union[TensorType, list[tuple]]
@@ -400,6 +400,10 @@ class SwiftFormerPreTrainedModel(PreTrainedModel):
400
400
  elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
401
401
  init.constant_(module.bias, 0)
402
402
  init.constant_(module.weight, 1.0)
403
+ if getattr(module, "running_mean", None) is not None:
404
+ init.zeros_(module.running_mean)
405
+ init.ones_(module.running_var)
406
+ init.zeros_(module.num_batches_tracked)
403
407
  elif isinstance(module, (SwiftFormerConvEncoder, SwiftFormerLocalRepresentation)):
404
408
  init.ones_(module.layer_scale)
405
409
  elif isinstance(module, SwiftFormerEncoderBlock):
@@ -411,18 +411,7 @@ class SwinSelfAttention(nn.Module):
411
411
  torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
412
412
  )
413
413
 
414
- # get pair-wise relative position index for each token inside the window
415
- coords_h = torch.arange(self.window_size[0])
416
- coords_w = torch.arange(self.window_size[1])
417
- coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
418
- coords_flatten = torch.flatten(coords, 1)
419
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
420
- relative_coords = relative_coords.permute(1, 2, 0).contiguous()
421
- relative_coords[:, :, 0] += self.window_size[0] - 1
422
- relative_coords[:, :, 1] += self.window_size[1] - 1
423
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
424
- relative_position_index = relative_coords.sum(-1)
425
- self.register_buffer("relative_position_index", relative_position_index)
414
+ self.register_buffer("relative_position_index", self.create_relative_position_index())
426
415
 
427
416
  self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
428
417
  self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
@@ -481,6 +470,20 @@ class SwinSelfAttention(nn.Module):
481
470
 
482
471
  return outputs
483
472
 
473
+ def create_relative_position_index(self):
474
+ # get pair-wise relative position index for each token inside the window
475
+ coords_h = torch.arange(self.window_size[0])
476
+ coords_w = torch.arange(self.window_size[1])
477
+ coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
478
+ coords_flatten = torch.flatten(coords, 1)
479
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
480
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
481
+ relative_coords[:, :, 0] += self.window_size[0] - 1
482
+ relative_coords[:, :, 1] += self.window_size[1] - 1
483
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
484
+ relative_position_index = relative_coords.sum(-1)
485
+ return relative_position_index
486
+
484
487
 
485
488
  class SwinSelfOutput(nn.Module):
486
489
  def __init__(self, config, dim):
@@ -823,6 +826,7 @@ class SwinPreTrainedModel(PreTrainedModel):
823
826
  init.zeros_(module.position_embeddings)
824
827
  elif isinstance(module, SwinSelfAttention):
825
828
  init.zeros_(module.relative_position_bias_table)
829
+ init.copy_(module.relative_position_index, module.create_relative_position_index())
826
830
 
827
831
 
828
832
  @auto_docstring
@@ -97,7 +97,6 @@ class Swin2SRImageProcessorFast(BaseImageProcessorFast):
97
97
  stacked_images = self.pad(stacked_images, size_divisor=size_divisor)
98
98
  processed_image_grouped[shape] = stacked_images
99
99
  processed_images = reorder_images(processed_image_grouped, grouped_images_index)
100
- processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
101
100
 
102
101
  return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
103
102
 
@@ -250,40 +250,8 @@ class Swin2SRSelfAttention(nn.Module):
250
250
  nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
251
251
  )
252
252
 
253
- # get relative_coords_table
254
- relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float()
255
- relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float()
256
- relative_coords_table = (
257
- torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
258
- .permute(1, 2, 0)
259
- .contiguous()
260
- .unsqueeze(0)
261
- ) # [1, 2*window_height - 1, 2*window_width - 1, 2]
262
- if pretrained_window_size[0] > 0:
263
- relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
264
- relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
265
- elif window_size > 1:
266
- relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
267
- relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
268
- relative_coords_table *= 8 # normalize to -8, 8
269
- relative_coords_table = (
270
- torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
271
- )
272
- # set to same dtype as mlp weight
273
- relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
253
+ relative_coords_table, relative_position_index = self.create_coords_table_and_index()
274
254
  self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
275
-
276
- # get pair-wise relative position index for each token inside the window
277
- coords_h = torch.arange(self.window_size[0])
278
- coords_w = torch.arange(self.window_size[1])
279
- coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
280
- coords_flatten = torch.flatten(coords, 1)
281
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
282
- relative_coords = relative_coords.permute(1, 2, 0).contiguous()
283
- relative_coords[:, :, 0] += self.window_size[0] - 1
284
- relative_coords[:, :, 1] += self.window_size[1] - 1
285
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
286
- relative_position_index = relative_coords.sum(-1)
287
255
  self.register_buffer("relative_position_index", relative_position_index, persistent=False)
288
256
 
289
257
  self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
@@ -359,6 +327,43 @@ class Swin2SRSelfAttention(nn.Module):
359
327
 
360
328
  return outputs
361
329
 
330
+ def create_coords_table_and_index(self):
331
+ # get relative_coords_table
332
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float()
333
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float()
334
+ relative_coords_table = (
335
+ torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
336
+ .permute(1, 2, 0)
337
+ .contiguous()
338
+ .unsqueeze(0)
339
+ ) # [1, 2*window_height - 1, 2*window_width - 1, 2]
340
+ if self.pretrained_window_size[0] > 0:
341
+ relative_coords_table[:, :, :, 0] /= self.pretrained_window_size[0] - 1
342
+ relative_coords_table[:, :, :, 1] /= self.pretrained_window_size[1] - 1
343
+ elif self.window_size[0] > 1:
344
+ relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
345
+ relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
346
+ relative_coords_table *= 8 # normalize to -8, 8
347
+ relative_coords_table = (
348
+ torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
349
+ )
350
+ # set to same dtype as mlp weight
351
+ relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
352
+
353
+ # get pair-wise relative position index for each token inside the window
354
+ coords_h = torch.arange(self.window_size[0])
355
+ coords_w = torch.arange(self.window_size[1])
356
+ coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
357
+ coords_flatten = torch.flatten(coords, 1)
358
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
359
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
360
+ relative_coords[:, :, 0] += self.window_size[0] - 1
361
+ relative_coords[:, :, 1] += self.window_size[1] - 1
362
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
363
+ relative_position_index = relative_coords.sum(-1)
364
+
365
+ return relative_coords_table, relative_position_index
366
+
362
367
 
363
368
  # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->Swin2SR
364
369
  class Swin2SRSelfOutput(nn.Module):
@@ -702,6 +707,17 @@ class Swin2SRPreTrainedModel(PreTrainedModel):
702
707
  elif isinstance(module, nn.LayerNorm):
703
708
  init.zeros_(module.bias)
704
709
  init.ones_(module.weight)
710
+ elif isinstance(module, Swin2SRSelfAttention):
711
+ init.constant_(module.logit_scale, math.log(10))
712
+ relative_coords_table, relative_position_index = module.create_coords_table_and_index()
713
+ init.copy_(module.relative_coords_table, relative_coords_table)
714
+ init.copy_(module.relative_position_index, relative_position_index)
715
+ elif isinstance(module, Swin2SRModel):
716
+ if module.config.num_channels == 3 and module.config.num_channels_out == 3:
717
+ mean = torch.tensor([0.4488, 0.4371, 0.4040]).view(1, 3, 1, 1)
718
+ else:
719
+ mean = torch.zeros(1, 1, 1, 1)
720
+ init.copy_(module.mean, mean)
705
721
 
706
722
 
707
723
  @auto_docstring