transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/t5gemma2/modular_t5gemma2.py
@@ -22,7 +22,7 @@ import torch.nn as nn
 
 from ... import initialization as init
 from ...cache_utils import DynamicCache, EncoderDecoderCache, StaticCache
-from ...configuration_utils import PreTrainedConfig
+from ...configuration_utils import PreTrainedConfig, layer_type_validation
 from ...generation import GenerationConfig, GenerationMixin, GenerationMode
 from ...masking_utils import create_bidirectional_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -34,6 +34,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import (
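
Both imports added above feed the configuration hunks that follow: layer_type_validation is called on the resolved per-layer attention list, and RopeParameters types the new rope_parameters argument. A minimal sketch of the call shapes, assuming only the two-argument signature visible at the call sites below; the rope dict key comes from the rope_parameters docstring further down, not from this hunk:

    from transformers.configuration_utils import layer_type_validation

    # Five sliding-window layers followed by one full-attention layer.
    layer_types = ["sliding_attention"] * 5 + ["full_attention"]
    layer_type_validation(layer_types, len(layer_types))

    # RopeParameters-style dict; the docstring below only names "rope_theta" as required.
    rope_parameters = {"rope_theta": 10_000.0}
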
@@ -70,9 +71,146 @@ from ..t5gemma.modeling_t5gemma import (
 logger = logging.get_logger(__name__)
 
 
-class T5Gemma2TextConfig(Gemma3TextConfig):
+class T5Gemma2TextConfig(Gemma3TextConfig, PreTrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`T5Gemma2TextModel`]. It is used to instantiate the encoder's
+    text model portion of the T5Gemma2 Model according to the specified arguments, defining the model architecture. Instantiating
+    a configuration with the defaults will yield a similar configuration to that of the T5Gemma2Text-7B.
+    e.g. [google/t5gemma2_text-7b](https://huggingface.co/google/t5gemma2_text-7b)
+    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PreTrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 262208):
+            Vocabulary size of the T5Gemma2Text model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`T5Gemma2TextModel`]
+        hidden_size (`int`, *optional*, defaults to 2304):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 9216):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 26):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*, defaults to 4):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `num_attention_heads`.
+        head_dim (`int`, *optional*, defaults to 256):
+            The attention head dimension.
+        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
+        max_position_embeddings (`int`, *optional*, defaults to 131072):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        eos_token_id (`int`, *optional*, defaults to 1):
+            End of stream token id.
+        bos_token_id (`int`, *optional*, defaults to 2):
+            Beginning of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+            Whether to tie weight embeddings
+        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        query_pre_attn_scalar (`float`, *optional*, defaults to 256):
+            Scaling factor used on the attention scores
+        sliding_window (`int`, *optional*, defaults to 4096):
+            In T5Gemma2Text, every other layer uses sliding window attention. This is the size of the sliding window.
+        layer_types (`list`, *optional*):
+            Attention pattern for each layer.
+        final_logit_softcapping (`float`, *optional*):
+            Scaling factor when applying tanh softcapping on the logits.
+        attn_logit_softcapping (`float`, *optional*):
+            Scaling factor when applying tanh softcapping on the attention scores.
+        rope_parameters (`RopeParameters`, *optional*):
+            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
+            a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
+            with longer `max_position_embeddings`.
+    """
+
     model_type = "t5gemma2_text"
 
+    def __init__(
+        self,
+        vocab_size: Optional[int] = 262_208,
+        hidden_size: Optional[int] = 2304,
+        intermediate_size: Optional[int] = 9216,
+        num_hidden_layers: Optional[int] = 26,
+        num_attention_heads: Optional[int] = 8,
+        num_key_value_heads: Optional[int] = 4,
+        head_dim: Optional[int] = 256,
+        hidden_activation: Optional[str] = "gelu_pytorch_tanh",
+        max_position_embeddings: Optional[int] = 131_072,
+        initializer_range: Optional[float] = 0.02,
+        rms_norm_eps: Optional[int] = 1e-6,
+        use_cache: Optional[bool] = True,
+        pad_token_id: Optional[int] = 0,
+        eos_token_id: Optional[int] = 1,
+        bos_token_id: Optional[int] = 2,
+        tie_word_embeddings: Optional[bool] = True,
+        attention_bias: Optional[bool] = False,
+        attention_dropout: Optional[float] = 0.0,
+        query_pre_attn_scalar: Optional[int] = 256,
+        sliding_window: Optional[int] = 4096,
+        layer_types: Optional[list[str]] = None,
+        final_logit_softcapping: Optional[float] = None,
+        attn_logit_softcapping: Optional[float] = None,
+        rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.head_dim = head_dim
+        self.num_key_value_heads = num_key_value_heads
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.hidden_activation = hidden_activation
+        self.query_pre_attn_scalar = query_pre_attn_scalar
+        self.sliding_window = sliding_window
+        self.final_logit_softcapping = final_logit_softcapping
+        self.attn_logit_softcapping = attn_logit_softcapping
+        self.layer_types = layer_types
+
+        # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
+        self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)
+
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+        self.rope_parameters = rope_parameters
+        PreTrainedConfig.__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
 
 class T5Gemma2EncoderConfig(Gemma3Config):
     model_type = "t5gemma2_encoder"
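
A hedged illustration of the layer_types default introduced above: the helper below is not part of the diff, it simply mirrors the list comprehension in T5Gemma2TextConfig.__init__, where the backward-compatible sliding_window_pattern (default 6) makes every sixth layer a full-attention layer:

    def default_layer_types(num_hidden_layers: int, pattern: int = 6) -> list[str]:
        # Layer i uses full attention when (i + 1) is a multiple of `pattern`,
        # matching the comprehension in the hunk above.
        return [
            "sliding_attention" if (i + 1) % pattern else "full_attention"
            for i in range(num_hidden_layers)
        ]

    print(default_layer_types(12).count("full_attention"))  # 2 (layers 6 and 12)
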
@@ -83,9 +221,146 @@ class T5Gemma2EncoderConfig(Gemma3Config):
83
221
  }
84
222
 
85
223
 
86
- class T5Gemma2DecoderConfig(Gemma3TextConfig):
224
+ class T5Gemma2DecoderConfig(Gemma3TextConfig, PreTrainedConfig):
225
+ r"""
226
+ This is the configuration class to store the configuration of a [`T5Gemma2DecoderModel`]. It is used to instantiate the decoder
227
+ text model portion of the T5Gemma2 Model according to the specified arguments, defining the model architecture. Instantiating
228
+ a configuration with the defaults will yield a similar configuration to that of the T5Gemma2Decoder-7B.
229
+ e.g. [google/t5gemma2_text-7b](https://huggingface.co/google/t5gemma2_text-7b)
230
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
231
+     documentation from [`PreTrainedConfig`] for more information.
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 262208):
+             Vocabulary size of the T5Gemma2Decoder model. Defines the number of different tokens that can be represented by the
+             `input_ids` passed when calling [`T5Gemma2DecoderModel`].
+         hidden_size (`int`, *optional*, defaults to 2304):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 9216):
+             Dimension of the MLP representations.
+         num_hidden_layers (`int`, *optional*, defaults to 26):
+             Number of hidden layers in the Transformer decoder.
+         num_attention_heads (`int`, *optional*, defaults to 8):
+             Number of attention heads for each attention layer in the Transformer decoder.
+         num_key_value_heads (`int`, *optional*, defaults to 4):
+             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+             `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+             `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+             by mean-pooling all the original heads within that group. For more details, check out [this
+             paper](https://huggingface.co/papers/2305.13245). If it is not specified, it will default to
+             `num_attention_heads`.
+         head_dim (`int`, *optional*, defaults to 256):
+             The attention head dimension.
+         hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+             The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+             if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
+         max_position_embeddings (`int`, *optional*, defaults to 131072):
+             The maximum sequence length that this model might ever be used with.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+             The epsilon used by the rms normalization layers.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models). Only
+             relevant if `config.is_decoder=True`.
+         pad_token_id (`int`, *optional*, defaults to 0):
+             Padding token id.
+         eos_token_id (`int`, *optional*, defaults to 1):
+             End of stream token id.
+         bos_token_id (`int`, *optional*, defaults to 2):
+             Beginning of stream token id.
+         tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+             Whether to tie the input and output word embeddings.
+         attention_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use a bias in the query, key, value and output projection layers during self-attention.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+         query_pre_attn_scalar (`float`, *optional*, defaults to 256):
+             Scaling factor used on the attention scores.
+         sliding_window (`int`, *optional*, defaults to 4096):
+             In T5Gemma2Decoder, every other layer uses sliding window attention. This is the size of the sliding window.
+         layer_types (`list`, *optional*):
+             Attention pattern for each layer.
+         final_logit_softcapping (`float`, *optional*):
+             Scaling factor when applying tanh softcapping on the logits.
+         attn_logit_softcapping (`float`, *optional*):
+             Scaling factor when applying tanh softcapping on the attention scores.
+         rope_parameters (`RopeParameters`, *optional*):
+             Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
+             a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
+             with longer `max_position_embeddings`.
+     """
+
      model_type = "t5gemma2_decoder"
 
+     def __init__(
+         self,
+         vocab_size: Optional[int] = 262_208,
+         hidden_size: Optional[int] = 2304,
+         intermediate_size: Optional[int] = 9216,
+         num_hidden_layers: Optional[int] = 26,
+         num_attention_heads: Optional[int] = 8,
+         num_key_value_heads: Optional[int] = 4,
+         head_dim: Optional[int] = 256,
+         hidden_activation: Optional[str] = "gelu_pytorch_tanh",
+         max_position_embeddings: Optional[int] = 131_072,
+         initializer_range: Optional[float] = 0.02,
+         rms_norm_eps: Optional[float] = 1e-6,
+         use_cache: Optional[bool] = True,
+         pad_token_id: Optional[int] = 0,
+         eos_token_id: Optional[int] = 1,
+         bos_token_id: Optional[int] = 2,
+         tie_word_embeddings: Optional[bool] = True,
+         attention_bias: Optional[bool] = False,
+         attention_dropout: Optional[float] = 0.0,
+         query_pre_attn_scalar: Optional[float] = 256,
+         sliding_window: Optional[int] = 4096,
+         layer_types: Optional[list[str]] = None,
+         final_logit_softcapping: Optional[float] = None,
+         attn_logit_softcapping: Optional[float] = None,
+         rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.head_dim = head_dim
+         self.num_key_value_heads = num_key_value_heads
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+         self.hidden_activation = hidden_activation
+         self.query_pre_attn_scalar = query_pre_attn_scalar
+         self.sliding_window = sliding_window
+         self.final_logit_softcapping = final_logit_softcapping
+         self.attn_logit_softcapping = attn_logit_softcapping
+         self.layer_types = layer_types
+
+         # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
+         self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)
+
+         if self.layer_types is None:
+             self.layer_types = [
+                 "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
+                 for i in range(self.num_hidden_layers)
+             ]
+         layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+         self.rope_parameters = rope_parameters
+         PreTrainedConfig.__init__(
+             self,
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
 
  class T5Gemma2Config(PreTrainedConfig):
      r"""
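The `num_key_value_heads` docstring above describes converting a multi-head checkpoint to GQA by mean-pooling the key/value heads of each group. A minimal sketch of that conversion, assuming a `[num_heads * head_dim, hidden]` projection layout with contiguous groups (names here are illustrative, not transformers API):

```python
import torch

def meanpool_kv_heads(w: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int) -> torch.Tensor:
    """Mean-pool a key/value projection from `num_heads` down to `num_kv_heads` groups."""
    group = num_heads // num_kv_heads
    # [num_heads * head_dim, hidden] -> [num_kv_heads, group, head_dim, hidden]
    w = w.view(num_kv_heads, group, head_dim, -1)
    return w.mean(dim=1).reshape(num_kv_heads * head_dim, -1)

# Example with the defaults above: 8 heads pooled into 4 key/value groups, head_dim=256
w_k = torch.randn(8 * 256, 2304)
w_k_gqa = meanpool_kv_heads(w_k, num_heads=8, num_kv_heads=4, head_dim=256)
assert w_k_gqa.shape == (4 * 256, 2304)
```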
@@ -257,6 +532,7 @@ class T5Gemma2RotaryEmbedding(Gemma3RotaryEmbedding):
  class T5Gemma2SelfAttention(Gemma3Attention):
      def __init__(self, config: T5Gemma2TextConfig, layer_idx: int):
          super().__init__(config, layer_idx)
+         self.is_causal = False  # Only used by the encoder
 
 
  class T5Gemma2MergedAttention(Gemma3Attention):
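The config's `__init__` above derives the default `layer_types` from `_sliding_window_pattern`: every pattern-th layer gets full attention, the rest use sliding-window attention. A quick standalone check of that comprehension, using the defaults from the hunk above:

```python
pattern, num_layers = 6, 12
layer_types = [
    "sliding_attention" if bool((i + 1) % pattern) else "full_attention"
    for i in range(num_layers)
]
# Layers 6 and 12 (1-indexed) are full attention, all others sliding:
assert layer_types[5] == layer_types[11] == "full_attention"
assert layer_types[0] == "sliding_attention"
```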
@@ -264,6 +540,7 @@ class T5Gemma2MergedAttention(Gemma3Attention):
 
      def __init__(self, config: T5Gemma2TextConfig, layer_idx: int):
          super().__init__(config, layer_idx)
+         self.is_causal = False  # Fused causal and encoder mask
 
      def forward(
          self,
@@ -342,7 +619,6 @@ class T5Gemma2MergedAttention(Gemma3Attention):
              merged_attention_mask,
              dropout=self.attention_dropout if self.training else 0.0,
              scaling=self.scaling,
-             is_causal=False,
              **kwargs,
          )
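The `attn_logit_softcapping` and `final_logit_softcapping` options in the config above follow the Gemma-2-style tanh softcap, which squashes scores into `(-cap, cap)` while staying roughly linear near zero. A minimal sketch of the operation (not the transformers implementation):

```python
import torch

def softcap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # tanh keeps |output| < cap; for |scores| << cap the result is ~scores
    return cap * torch.tanh(scores / cap)

scores = torch.tensor([0.5, 30.0, 300.0])
print(softcap(scores, cap=50.0))  # ~[0.5, 26.9, 50.0] -- large logits saturate at the cap
```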
 
@@ -498,6 +774,7 @@ class T5Gemma2PreTrainedModel(Gemma3PreTrainedModel):
              init.zeros_(module.mm_input_projection_weight)
          elif isinstance(module, T5Gemma2TextScaledWordEmbedding):
              init.zeros_(module.eoi_embedding)
+             init.constant_(module.embed_scale, module.scalar_embed_scale)
          elif isinstance(module, T5Gemma2ClassificationHead):
              scale = module.out_proj.weight.shape[0] ** -0.5
              init.normal_(module.out_proj.weight, mean=0.0, std=self.config.initializer_range * scale)
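The next hunk re-materializes the rotary-embedding `inv_freq` buffers per layer type, which is needed when the model was instantiated on the meta device. For the `"default"` rope type, the initializer boils down to the standard RoPE frequency schedule; a minimal sketch of that computation (illustrative values, not the transformers helper):

```python
import torch

def default_inv_freq(rope_theta: float, head_dim: int) -> torch.Tensor:
    # Standard RoPE: inv_freq[i] = 1 / theta^(2i / head_dim)
    return 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

inv_freq = default_inv_freq(rope_theta=10_000.0, head_dim=256)
assert inv_freq.shape == (128,)  # one frequency per rotated pair of dimensions
```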
@@ -506,6 +783,14 @@ class T5Gemma2PreTrainedModel(Gemma3PreTrainedModel):
          # We initialize with 0s to be 1-centered, as the RMSNorm here computes (1 + weight)
          elif "RMSNorm" in module.__class__.__name__:
              init.zeros_(module.weight)
+         elif isinstance(module, T5Gemma2RotaryEmbedding):
+             for layer_type in module.layer_types:
+                 rope_init_fn = module.compute_default_rope_parameters
+                 if module.rope_type[layer_type] != "default":
+                     rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                 curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                 init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                 init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
 
      def prepare_decoder_input_ids_from_labels(self, input_ids):
          """
@@ -37,7 +37,7 @@ class TableTransformerConfig(PreTrainedConfig):
          use_timm_backbone (`bool`, *optional*, defaults to `True`):
              Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
              API.
-         backbone_config (`PreTrainedConfig` or `dict`, *optional*):
+         backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `ResNetConfig()`):
              The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
              case it will default to `ResNetConfig()`.
          num_channels (`int`, *optional*, defaults to 3):
@@ -702,7 +702,7 @@ class TableTransformerPreTrainedModel(PreTrainedModel):
          if isinstance(module, TableTransformerLearnedPositionEmbedding):
              init.uniform_(module.row_embeddings.weight)
              init.uniform_(module.column_embeddings.weight)
-         if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+         if isinstance(module, (nn.Linear, nn.Conv2d)):
              init.normal_(module.weight, mean=0.0, std=std)
              if module.bias is not None:
                  init.zeros_(module.bias)
@@ -137,7 +137,6 @@ class TextNetImageProcessorFast(BaseImageProcessorFast):
              processed_images_grouped[shape] = stacked_images
 
          processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-         processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
          return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
@@ -144,6 +144,7 @@ class TimesFmPositionalEmbedding(nn.Module):
          super().__init__()
          min_timescale = config.min_timescale
          max_timescale = config.max_timescale
+         self.min_timescale, self.max_timescale = min_timescale, max_timescale
          self.embedding_dims = config.hidden_size
 
          num_timescales = self.embedding_dims // 2
@@ -313,6 +314,17 @@ class TimesFmPreTrainedModel(PreTrainedModel):
          if isinstance(module, TimesFmAttention):
              # Initialize scaling parameter
              init.ones_(module.scaling)
+         elif isinstance(module, TimesFmPositionalEmbedding):
+             num_timescales = module.embedding_dims // 2
+             max_timescale, min_timescale = module.max_timescale, module.min_timescale
+             log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
+                 num_timescales - 1, 1
+             )
+             init.copy_(
+                 module.inv_timescales,
+                 min_timescale
+                 * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
+             )
 
 
  @auto_docstring
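The `inv_timescales` re-init above is the classic sinusoidal positional-embedding schedule: timescales are geometrically spaced between `min_timescale` and `max_timescale`, i.e. `inv_timescales[i] = min_ts * (max_ts / min_ts)^(-i / (n - 1))`. A quick standalone check that the exponential form used in the hunk matches that geometric form (sample values only):

```python
import math
import torch

min_ts, max_ts, n = 1.0, 10_000.0, 64
log_inc = math.log(max_ts / min_ts) / max(n - 1, 1)
via_exp = min_ts * torch.exp(torch.arange(n, dtype=torch.float32) * -log_inc)
via_pow = min_ts * (max_ts / min_ts) ** (-torch.arange(n, dtype=torch.float32) / (n - 1))
assert torch.allclose(via_exp, via_pow)
```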
@@ -123,6 +123,7 @@ class TimesFmPositionalEmbedding(nn.Module):
          super().__init__()
          min_timescale = config.min_timescale
          max_timescale = config.max_timescale
+         self.min_timescale, self.max_timescale = min_timescale, max_timescale
          self.embedding_dims = config.hidden_size
 
          num_timescales = self.embedding_dims // 2
@@ -269,6 +270,17 @@ class TimesFmPreTrainedModel(PreTrainedModel):
          if isinstance(module, TimesFmAttention):
              # Initialize scaling parameter
              init.ones_(module.scaling)
+         elif isinstance(module, TimesFmPositionalEmbedding):
+             num_timescales = module.embedding_dims // 2
+             max_timescale, min_timescale = module.max_timescale, module.min_timescale
+             log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
+                 num_timescales - 1, 1
+             )
+             init.copy_(
+                 module.inv_timescales,
+                 min_timescale
+                 * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
+             )
 
 
  @auto_docstring
@@ -16,10 +16,12 @@
  from typing import Optional, Union
 
  import torch
+ from torch import Tensor, nn
 
+ from ... import initialization as init
  from ...modeling_outputs import BackboneOutput
  from ...modeling_utils import PreTrainedModel
- from ...utils import is_timm_available, is_torch_available, requires_backends
+ from ...utils import is_timm_available, requires_backends
  from ...utils.backbone_utils import BackboneMixin
  from .configuration_timm_backbone import TimmBackboneConfig
 
@@ -28,10 +30,6 @@ if is_timm_available():
      import timm
 
 
- if is_torch_available():
-     from torch import Tensor
-
-
  class TimmBackbone(PreTrainedModel, BackboneMixin):
      """
      Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the
@@ -84,10 +82,11 @@ class TimmBackbone(PreTrainedModel, BackboneMixin):
          self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
          super()._init_backbone(config)
 
+         self.post_init()
+
      @classmethod
      def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
          requires_backends(cls, ["vision", "timm"])
-         from ...models.timm_backbone import TimmBackboneConfig
 
          config = kwargs.pop("config", TimmBackboneConfig())
 
@@ -116,9 +115,14 @@ class TimmBackbone(PreTrainedModel, BackboneMixin):
 
      @torch.no_grad()
      def _init_weights(self, module):
-         """
-         Empty init weights function to ensure compatibility of the class in the library.
-         """
+         """We need to at least re-init the non-persistent buffers if the model was initialized on the meta device (we
+         assume weights and persistent buffers will be part of the checkpoint, as we have no way to control timm inits)."""
+         if hasattr(module, "init_non_persistent_buffers"):
+             module.init_non_persistent_buffers()
+         elif isinstance(module, nn.BatchNorm2d) and getattr(module, "running_mean", None) is not None:
+             init.zeros_(module.running_mean)
+             init.ones_(module.running_var)
+             init.zeros_(module.num_batches_tracked)
 
      def forward(
          self,
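The new `_init_weights` above exists because a model instantiated on the meta device has buffers without real storage; once materialized, anything not covered by the checkpoint must be re-filled by hand. A minimal standalone illustration of the problem and the fix (not transformers code):

```python
import torch
from torch import nn

# Buffers created under a meta-device context hold no data...
with torch.device("meta"):
    bn = nn.BatchNorm2d(4)

# ...and materializing allocates storage with arbitrary contents,
bn = bn.to_empty(device="cpu")

# so the statistics must be restored to their usual defaults by hand.
torch.nn.init.zeros_(bn.running_mean)
torch.nn.init.ones_(bn.running_var)
bn.num_batches_tracked.zero_()
```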
@@ -81,6 +81,9 @@ class TimmWrapperConfig(PreTrainedConfig):
 
      @classmethod
      def from_dict(cls, config_dict: dict[str, Any], **kwargs):
+         # Create a copy to avoid mutating the original dict
+         config_dict = config_dict.copy()
+
          label_names = config_dict.get("label_names")
          is_custom_model = "num_labels" in kwargs or "id2label" in kwargs
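The `config_dict.copy()` added above guards against a classic aliasing bug: `from_dict` later pops or rewrites keys, and without the shallow copy those mutations would leak back into the dict the caller passed in. A short illustration of the hazard:

```python
caller_dict = {"label_names": ["cat", "dog"], "num_classes": 2}
config_dict = caller_dict.copy()     # shallow copy: top-level pops no longer leak
config_dict.pop("num_classes")
assert "num_classes" in caller_dict  # the caller's dict is untouched
```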
 
@@ -84,16 +84,13 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
      main_input_name = "pixel_values"
      input_modalities = ("image",)
      config: TimmWrapperConfig
-     _no_split_modules = []
+     # Workaround: `timm` does not support model parallelism, so never split the wrapped model
+     _no_split_modules = ["TimmWrapperModel"]
      model_tags = ["timm"]
 
      # used in Trainer to avoid passing `loss_kwargs` to model forward
      accepts_loss_kwargs = False
 
-     def __init__(self, *args, **kwargs):
-         requires_backends(self, ["vision", "timm"])
-         super().__init__(*args, **kwargs)
-
      def post_init(self):
          self.supports_gradient_checkpointing = self._timm_model_supports_gradient_checkpointing()
          super().post_init()
@@ -113,10 +110,17 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
          Since model architectures may vary, we assume only the classifier requires
          initialization, while all other weights should be loaded from the checkpoint.
          """
-         if isinstance(module, (nn.Linear)):
+         if isinstance(module, nn.Linear):
              init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
              if module.bias is not None:
                  init.zeros_(module.bias)
+         # Also, re-init all non-persistent buffers, if any!
+         if hasattr(module, "init_non_persistent_buffers"):
+             module.init_non_persistent_buffers()
+         elif isinstance(module, nn.BatchNorm2d) and getattr(module, "running_mean", None) is not None:
+             init.zeros_(module.running_mean)
+             init.ones_(module.running_var)
+             init.zeros_(module.num_batches_tracked)
 
 
      def _timm_model_supports_gradient_checkpointing(self):
@@ -136,6 +140,13 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
      def _set_gradient_checkpointing(self, enable: bool = True, *args, **kwargs):
          self.timm_model.set_grad_checkpointing(enable)
 
+     def get_input_embeddings(self):
+         # TIMM backbones operate directly on images and do not expose token embeddings.
+         return None
+
+     def set_input_embeddings(self, value):
+         raise NotImplementedError("TimmWrapper models do not own token embeddings and cannot set them.")
+
 
  class TimmWrapperModel(TimmWrapperPreTrainedModel):
      """
@@ -143,6 +154,7 @@ class TimmWrapperModel(TimmWrapperPreTrainedModel):
      """
 
      def __init__(self, config: TimmWrapperConfig):
+         requires_backends(self, ["vision", "timm"])
          super().__init__(config)
          # using num_classes=0 to avoid creating classification head
          extra_init_kwargs = config.model_args or {}
@@ -150,13 +162,6 @@ class TimmWrapperModel(TimmWrapperPreTrainedModel):
          self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs)
          self.post_init()
 
-     def get_input_embeddings(self):
-         # Vision backbones from timm do not expose token embeddings, so there is nothing to return.
-         return None
-
-     def set_input_embeddings(self, value):
-         raise NotImplementedError("TimmWrapperModel does not own token embeddings and cannot set them.")
-
      @auto_docstring
      def forward(
          self,
@@ -265,6 +270,7 @@ class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel):
      """
 
      def __init__(self, config: TimmWrapperConfig):
+         requires_backends(self, ["vision", "timm"])
          super().__init__(config)
 
          if config.num_labels == 0:
@@ -89,7 +89,6 @@ class TrOCRSinusoidalPositionalEmbedding(nn.Module):
          self.embedding_dim = embedding_dim
          self.padding_idx = padding_idx
          self.weights = self.get_embedding(num_positions, embedding_dim, padding_idx)
-         self.register_buffer("_float_tensor", torch.FloatTensor(1))
 
      @staticmethod
      def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
@@ -123,7 +122,6 @@ class TrOCRSinusoidalPositionalEmbedding(nn.Module):
          if self.weights is None or max_pos > self.weights.size(0):
              # recompute/expand embeddings if needed
              self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx)
-             self.weights = self.weights.to(self._float_tensor)
 
          x = self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
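The `get_embedding` method referenced above builds the sinusoidal table that these hunks now keep device-agnostic (the `_float_tensor` dtype-tracking buffer is gone). A sketch of that standard construction, assuming the usual fairseq-style recipe rather than quoting the actual method:

```python
import math
import torch

def sinusoidal_table(num_embeddings: int, embedding_dim: int) -> torch.Tensor:
    half_dim = embedding_dim // 2
    freq = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -(math.log(10_000) / (half_dim - 1)))
    angles = torch.arange(num_embeddings, dtype=torch.float32).unsqueeze(1) * freq.unsqueeze(0)
    # [num_embeddings, embedding_dim]: sines in the first half, cosines in the second
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

table = sinusoidal_table(num_embeddings=512, embedding_dim=256)
assert table.shape == (512, 256)
```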
 
@@ -636,6 +634,7 @@ class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.decoder = TrOCRDecoder(config)
+         self.post_init()
 
      def forward(self, *args, **kwargs):
          return self.decoder(*args, **kwargs)
@@ -35,7 +35,7 @@ class TvpConfig(PreTrainedConfig):
 
 
      Args:
-         backbone_config (`PreTrainedConfig` or `dict`, *optional*):
+         backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `ResNetConfig()`):
              The configuration of the backbone model.
          backbone (`str`, *optional*):
              Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -68,6 +68,8 @@ class TvpConfig(PreTrainedConfig):
          vocab_size (`int`, *optional*, defaults to 30522):
              Vocabulary size of the Tvp text model. Defines the number of different tokens that can be represented by
              the `input_ids` passed when calling [`TvpModel`].
+         type_vocab_size (`int`, *optional*, defaults to 2):
+             The vocabulary size of the `token_type_ids` passed when calling [`TvpModel`].
          hidden_size (`int`, *optional*, defaults to 768):
              Dimensionality of the encoder layers.
          intermediate_size (`int`, *optional*, defaults to 3072):
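The new `type_vocab_size` entry above bounds the values allowed in `token_type_ids`; with the default of 2, a sequence can distinguish at most two segments. A hypothetical illustration:

```python
# Hypothetical segment ids for a text pair; values must stay below type_vocab_size=2
token_type_ids = [0, 0, 0, 0, 1, 1, 1]
assert max(token_type_ids) < 2
```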
@@ -114,6 +116,7 @@ class TvpConfig(PreTrainedConfig):
          max_img_size=448,
          num_frames=48,
          vocab_size=30522,
+         type_vocab_size=2,
          hidden_size=768,
          intermediate_size=3072,
          num_hidden_layers=12,
@@ -157,6 +160,7 @@ class TvpConfig(PreTrainedConfig):
          self.max_img_size = max_img_size
          self.num_frames = num_frames
          self.vocab_size = vocab_size
+         self.type_vocab_size = type_vocab_size
          self.hidden_size = hidden_size
          self.intermediate_size = intermediate_size
          self.num_hidden_layers = num_hidden_layers