transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,7 @@ import torch.nn as nn
 
 from ... import initialization as init
 from ...cache_utils import DynamicCache, EncoderDecoderCache, StaticCache
-from ...configuration_utils import PreTrainedConfig
+from ...configuration_utils import PreTrainedConfig, layer_type_validation
 from ...generation import GenerationConfig, GenerationMixin, GenerationMode
 from ...masking_utils import create_bidirectional_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -34,6 +34,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import (
@@ -70,9 +71,146 @@ from ..t5gemma.modeling_t5gemma import (
 logger = logging.get_logger(__name__)
 
 
-class T5Gemma2TextConfig(Gemma3TextConfig):
+class T5Gemma2TextConfig(Gemma3TextConfig, PreTrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`T5Gemma2TextModel`]. It is used to instantiate the encoder's
+    text model portion of the T5Gemma2 Model according to the specified arguments, defining the model architecture. Instantiating
+    a configuration with the defaults will yield a similar configuration to that of the T5Gemma2Text-7B,
+    e.g. [google/t5gemma2_text-7b](https://huggingface.co/google/t5gemma2_text-7b).
+
+    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PreTrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 262208):
+            Vocabulary size of the T5Gemma2Text model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`T5Gemma2TextModel`].
+        hidden_size (`int`, *optional*, defaults to 2304):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 9216):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 26):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*, defaults to 4):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `num_attention_heads`.
+        head_dim (`int`, *optional*, defaults to 256):
+            The attention head dimension.
+        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
+        max_position_embeddings (`int`, *optional*, defaults to 131072):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        eos_token_id (`int`, *optional*, defaults to 1):
+            End of stream token id.
+        bos_token_id (`int`, *optional*, defaults to 2):
+            Beginning of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+            Whether to tie weight embeddings.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        query_pre_attn_scalar (`float`, *optional*, defaults to 256):
+            Scaling factor used on the attention scores.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            In T5Gemma2Text, every other layer uses sliding window attention. This is the size of the sliding window.
+        layer_types (`list`, *optional*):
+            Attention pattern for each layer.
+        final_logit_softcapping (`float`, *optional*):
+            Scaling factor when applying tanh softcapping on the logits.
+        attn_logit_softcapping (`float`, *optional*):
+            Scaling factor when applying tanh softcapping on the attention scores.
+        rope_parameters (`RopeParameters`, *optional*):
+            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
+            a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
+            with longer `max_position_embeddings`.
+    """
+
     model_type = "t5gemma2_text"
 
+    def __init__(
+        self,
+        vocab_size: Optional[int] = 262_208,
+        hidden_size: Optional[int] = 2304,
+        intermediate_size: Optional[int] = 9216,
+        num_hidden_layers: Optional[int] = 26,
+        num_attention_heads: Optional[int] = 8,
+        num_key_value_heads: Optional[int] = 4,
+        head_dim: Optional[int] = 256,
+        hidden_activation: Optional[str] = "gelu_pytorch_tanh",
157
+ max_position_embeddings: Optional[int] = 131_072,
158
+ initializer_range: Optional[float] = 0.02,
159
+ rms_norm_eps: Optional[int] = 1e-6,
160
+ use_cache: Optional[bool] = True,
161
+ pad_token_id: Optional[int] = 0,
162
+ eos_token_id: Optional[int] = 1,
163
+ bos_token_id: Optional[int] = 2,
164
+ tie_word_embeddings: Optional[bool] = True,
165
+ attention_bias: Optional[bool] = False,
166
+ attention_dropout: Optional[float] = 0.0,
167
+ query_pre_attn_scalar: Optional[int] = 256,
168
+ sliding_window: Optional[int] = 4096,
169
+ layer_types: Optional[list[str]] = None,
170
+ final_logit_softcapping: Optional[float] = None,
171
+ attn_logit_softcapping: Optional[float] = None,
172
+ rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
173
+ **kwargs,
174
+ ):
175
+ self.vocab_size = vocab_size
176
+ self.max_position_embeddings = max_position_embeddings
177
+ self.hidden_size = hidden_size
178
+ self.intermediate_size = intermediate_size
179
+ self.num_hidden_layers = num_hidden_layers
180
+ self.num_attention_heads = num_attention_heads
181
+ self.head_dim = head_dim
182
+ self.num_key_value_heads = num_key_value_heads
183
+ self.initializer_range = initializer_range
184
+ self.rms_norm_eps = rms_norm_eps
185
+ self.use_cache = use_cache
186
+ self.attention_bias = attention_bias
187
+ self.attention_dropout = attention_dropout
188
+ self.hidden_activation = hidden_activation
189
+ self.query_pre_attn_scalar = query_pre_attn_scalar
190
+ self.sliding_window = sliding_window
191
+ self.final_logit_softcapping = final_logit_softcapping
192
+ self.attn_logit_softcapping = attn_logit_softcapping
193
+ self.layer_types = layer_types
194
+
195
+ # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
196
+ self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)
197
+
198
+ if self.layer_types is None:
199
+ self.layer_types = [
200
+ "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
201
+ for i in range(self.num_hidden_layers)
202
+ ]
203
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
204
+
205
+ self.rope_parameters = rope_parameters
206
+ PreTrainedConfig.__init__(
207
+ pad_token_id=pad_token_id,
208
+ bos_token_id=bos_token_id,
209
+ eos_token_id=eos_token_id,
210
+ tie_word_embeddings=tie_word_embeddings,
211
+ **kwargs,
212
+ )
213
+
76
214
 
77
215
  class T5Gemma2EncoderConfig(Gemma3Config):
78
216
  model_type = "t5gemma2_encoder"
@@ -83,9 +221,146 @@ class T5Gemma2EncoderConfig(Gemma3Config):
      }
 
 
- class T5Gemma2DecoderConfig(Gemma3TextConfig):
+ class T5Gemma2DecoderConfig(Gemma3TextConfig, PreTrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`T5Gemma2DecoderModel`]. It is used to instantiate the decoder
+     text model portion of the T5Gemma2 Model according to the specified arguments, defining the model architecture. Instantiating
+     a configuration with the defaults will yield a similar configuration to that of the T5Gemma2Decoder-7B,
+     e.g. [google/t5gemma2_text-7b](https://huggingface.co/google/t5gemma2_text-7b).
+     Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PreTrainedConfig`] for more information.
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 262208):
+             Vocabulary size of the T5Gemma2Decoder model. Defines the number of different tokens that can be represented by the
+             `inputs_ids` passed when calling [`T5Gemma2DecoderModel`].
+         hidden_size (`int`, *optional*, defaults to 2304):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 9216):
+             Dimension of the MLP representations.
+         num_hidden_layers (`int`, *optional*, defaults to 26):
+             Number of hidden layers in the Transformer decoder.
+         num_attention_heads (`int`, *optional*, defaults to 8):
+             Number of attention heads for each attention layer in the Transformer decoder.
+         num_key_value_heads (`int`, *optional*, defaults to 4):
+             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+             `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+             `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+             by meanpooling all the original heads within that group. For more details, check out [this
+             paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+             `num_attention_heads`.
+         head_dim (`int`, *optional*, defaults to 256):
+             The attention head dimension.
+         hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+             The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+             if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
+         max_position_embeddings (`int`, *optional*, defaults to 131072):
+             The maximum sequence length that this model might ever be used with.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+             The epsilon used by the rms normalization layers.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models). Only
+             relevant if `config.is_decoder=True`.
+         pad_token_id (`int`, *optional*, defaults to 0):
+             Padding token id.
+         eos_token_id (`int`, *optional*, defaults to 1):
+             End of stream token id.
+         bos_token_id (`int`, *optional*, defaults to 2):
+             Beginning of stream token id.
+         tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+             Whether to tie weight embeddings.
+         attention_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use a bias in the query, key, value and output projection layers during self-attention.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+         query_pre_attn_scalar (`float`, *optional*, defaults to 256):
+             Scaling factor used on the attention scores.
+         sliding_window (`int`, *optional*, defaults to 4096):
+             In T5Gemma2Decoder, every other layer uses sliding window attention. This is the size of the sliding window.
+         layer_types (`list`, *optional*):
+             Attention pattern for each layer.
+         final_logit_softcapping (`float`, *optional*):
+             Scaling factor when applying tanh softcapping on the logits.
+         attn_logit_softcapping (`float`, *optional*):
+             Scaling factor when applying tanh softcapping on the attention scores.
+         rope_parameters (`RopeParameters`, *optional*):
+             Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
+             a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
+             with longer `max_position_embeddings`.
+     """
+
      model_type = "t5gemma2_decoder"
 
+     def __init__(
+         self,
+         vocab_size: Optional[int] = 262_208,
+         hidden_size: Optional[int] = 2304,
+         intermediate_size: Optional[int] = 9216,
+         num_hidden_layers: Optional[int] = 26,
+         num_attention_heads: Optional[int] = 8,
+         num_key_value_heads: Optional[int] = 4,
+         head_dim: Optional[int] = 256,
+         hidden_activation: Optional[str] = "gelu_pytorch_tanh",
+         max_position_embeddings: Optional[int] = 131_072,
+         initializer_range: Optional[float] = 0.02,
+         rms_norm_eps: Optional[float] = 1e-6,
+         use_cache: Optional[bool] = True,
+         pad_token_id: Optional[int] = 0,
+         eos_token_id: Optional[int] = 1,
+         bos_token_id: Optional[int] = 2,
+         tie_word_embeddings: Optional[bool] = True,
+         attention_bias: Optional[bool] = False,
+         attention_dropout: Optional[float] = 0.0,
+         query_pre_attn_scalar: Optional[int] = 256,
+         sliding_window: Optional[int] = 4096,
+         layer_types: Optional[list[str]] = None,
+         final_logit_softcapping: Optional[float] = None,
+         attn_logit_softcapping: Optional[float] = None,
+         rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.head_dim = head_dim
+         self.num_key_value_heads = num_key_value_heads
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+         self.hidden_activation = hidden_activation
+         self.query_pre_attn_scalar = query_pre_attn_scalar
+         self.sliding_window = sliding_window
+         self.final_logit_softcapping = final_logit_softcapping
+         self.attn_logit_softcapping = attn_logit_softcapping
+         self.layer_types = layer_types
+
+         # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
+         self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)
+
+         if self.layer_types is None:
+             self.layer_types = [
+                 "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
+                 for i in range(self.num_hidden_layers)
+             ]
+         layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+         self.rope_parameters = rope_parameters
+         PreTrainedConfig.__init__(
+             self,
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
 
 
  class T5Gemma2Config(PreTrainedConfig):
      r"""
@@ -257,6 +532,7 @@ class T5Gemma2RotaryEmbedding(Gemma3RotaryEmbedding):
  class T5Gemma2SelfAttention(Gemma3Attention):
      def __init__(self, config: T5Gemma2TextConfig, layer_idx: int):
          super().__init__(config, layer_idx)
+         self.is_causal = False  # Only used by the encoder
 
 
  class T5Gemma2MergedAttention(Gemma3Attention):
@@ -264,6 +540,7 @@ class T5Gemma2MergedAttention(Gemma3Attention):
 
      def __init__(self, config: T5Gemma2TextConfig, layer_idx: int):
          super().__init__(config, layer_idx)
+         self.is_causal = False  # Fused causal and encoder mask
 
      def forward(
          self,
@@ -342,7 +619,6 @@ class T5Gemma2MergedAttention(Gemma3Attention):
              merged_attention_mask,
              dropout=self.attention_dropout if self.training else 0.0,
              scaling=self.scaling,
-             is_causal=False,
              **kwargs,
          )
 
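
Note: the three hunks above move the non-causal flag from a per-call kwarg to the `is_causal` instance attribute, which the shared attention interface presumably reads as the default. The effect of the flag itself can be seen with plain PyTorch SDPA (toy shapes, no transformers code):

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 16, 256)  # (batch, heads, seq_len, head_dim)
bidirectional = F.scaled_dot_product_attention(q, k, v, is_causal=False)
causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# Masking out future positions changes every query position except the first
assert not torch.allclose(bidirectional, causal)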
@@ -498,6 +774,7 @@ class T5Gemma2PreTrainedModel(Gemma3PreTrainedModel):
              init.zeros_(module.mm_input_projection_weight)
          elif isinstance(module, T5Gemma2TextScaledWordEmbedding):
              init.zeros_(module.eoi_embedding)
+             init.constant_(module.embed_scale, module.scalar_embed_scale)
          elif isinstance(module, T5Gemma2ClassificationHead):
              scale = module.out_proj.weight.shape[0] ** -0.5
              init.normal_(module.out_proj.weight, mean=0.0, std=self.config.initializer_range * scale)
@@ -506,6 +783,14 @@ class T5Gemma2PreTrainedModel(Gemma3PreTrainedModel):
              # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
          elif "RMSNorm" in module.__class__.__name__:
              init.zeros_(module.weight)
+         elif isinstance(module, T5Gemma2RotaryEmbedding):
+             for layer_type in module.layer_types:
+                 rope_init_fn = module.compute_default_rope_parameters
+                 if module.rope_type[layer_type] != "default":
+                     rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                 curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                 init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                 init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
 
      def prepare_decoder_input_ids_from_labels(self, input_ids):
          """
@@ -37,7 +37,7 @@ class TableTransformerConfig(PreTrainedConfig):
      use_timm_backbone (`bool`, *optional*, defaults to `True`):
          Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
          API.
-     backbone_config (`PreTrainedConfig` or `dict`, *optional*):
+     backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `ResNetConfig()`):
          The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
          case it will default to `ResNetConfig()`.
      num_channels (`int`, *optional*, defaults to 3):
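
Note: a hedged usage sketch of the documented fallback. `TableTransformerConfig` and `ResNetConfig` are existing top-level exports, but this is an illustration of the docstring, not a prescribed recipe:

from transformers import ResNetConfig, TableTransformerConfig

# With use_timm_backbone=False the backbone comes from the AutoBackbone API, and
# backbone_config defaults to ResNetConfig() when left unset
config = TableTransformerConfig(use_timm_backbone=False, backbone_config=ResNetConfig())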
@@ -702,7 +702,7 @@ class TableTransformerPreTrainedModel(PreTrainedModel):
          if isinstance(module, TableTransformerLearnedPositionEmbedding):
              init.uniform_(module.row_embeddings.weight)
              init.uniform_(module.column_embeddings.weight)
-         if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+         if isinstance(module, (nn.Linear, nn.Conv2d)):
              init.normal_(module.weight, mean=0.0, std=std)
              if module.bias is not None:
                  init.zeros_(module.bias)
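
Note: a plausible rationale for dropping `nn.BatchNorm2d` from the `init.normal_` branch is that BatchNorm affine parameters are conventionally left at PyTorch's identity initialization rather than sampled from a normal distribution:

import torch.nn as nn

bn = nn.BatchNorm2d(64)
# PyTorch's own default: weight (gamma) = 1, bias (beta) = 0
assert (bn.weight == 1).all() and (bn.bias == 0).all()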
@@ -749,6 +749,7 @@ class TableTransformerEncoder(TableTransformerPreTrainedModel):
          output_attentions=None,
          output_hidden_states=None,
          return_dict=None,
+         **kwargs,
      ):
          r"""
          Args:
@@ -869,6 +870,7 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel):
          output_attentions=None,
          output_hidden_states=None,
          return_dict=None,
+         **kwargs,
      ):
          r"""
          Args:
@@ -1043,6 +1045,7 @@ class TableTransformerModel(TableTransformerPreTrainedModel):
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple[torch.FloatTensor], TableTransformerModelOutput]:
          r"""
          decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -1202,6 +1205,7 @@ class TableTransformerForObjectDetection(TableTransformerPreTrainedModel):
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple[torch.FloatTensor], TableTransformerObjectDetectionOutput]:
          r"""
          decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -563,6 +563,7 @@ class TapasModel(TapasPreTrainedModel):
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple, BaseModelOutputWithPooling]:
          r"""
          token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 7)`, *optional*):
@@ -843,6 +844,7 @@ class TapasForQuestionAnswering(TapasPreTrainedModel):
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple, TableQuestionAnsweringOutput]:
          r"""
          token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 7)`, *optional*):
@@ -1164,6 +1166,7 @@ class TapasForSequenceClassification(TapasPreTrainedModel):
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
          r"""
          token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 7)`, *optional*):
@@ -137,7 +137,6 @@ class TextNetImageProcessorFast(BaseImageProcessorFast):
              processed_images_grouped[shape] = stacked_images
 
          processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-         processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
          return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
@@ -233,7 +233,11 @@ class TextNetModel(TextNetPreTrainedModel):
 
      @auto_docstring
      def forward(
-         self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+         self,
+         pixel_values: Tensor,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple[Any, list[Any]], tuple[Any], BaseModelOutputWithPoolingAndNoAttention]:
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict
          output_hidden_states = (
@@ -288,6 +292,7 @@ class TextNetForImageClassification(TextNetPreTrainedModel):
          labels: Optional[torch.LongTensor] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> ImageClassifierOutputWithNoAttention:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -353,7 +358,11 @@ class TextNetBackbone(TextNetPreTrainedModel, BackboneMixin):
 
      @auto_docstring
      def forward(
-         self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+         self,
+         pixel_values: Tensor,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple[tuple], BackboneOutput]:
          r"""
          Examples:
@@ -658,6 +658,7 @@ class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel):
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple, BaseModelOutput]:
          r"""
          Args:
@@ -777,6 +778,7 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel):
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
      ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
          r"""
          Args:
@@ -1075,6 +1077,7 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
          use_cache: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
      ) -> Union[Seq2SeqTSModelOutput, tuple]:
          r"""
          past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`):
@@ -1320,6 +1323,7 @@ class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel):
          use_cache: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
      ) -> Union[Seq2SeqTSModelOutput, tuple]:
          r"""
          past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`):
@@ -144,6 +144,7 @@ class TimesFmPositionalEmbedding(nn.Module):
          super().__init__()
          min_timescale = config.min_timescale
          max_timescale = config.max_timescale
+         self.min_timescale, self.max_timescale = min_timescale, max_timescale
          self.embedding_dims = config.hidden_size
 
          num_timescales = self.embedding_dims // 2
@@ -313,6 +314,17 @@ class TimesFmPreTrainedModel(PreTrainedModel):
          if isinstance(module, TimesFmAttention):
              # Initialize scaling parameter
              init.ones_(module.scaling)
+         elif isinstance(module, TimesFmPositionalEmbedding):
+             num_timescales = module.embedding_dims // 2
+             max_timescale, min_timescale = module.max_timescale, module.min_timescale
+             log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
+                 num_timescales - 1, 1
+             )
+             init.copy_(
+                 module.inv_timescales,
+                 min_timescale
+                 * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
+             )
 
 
  @auto_docstring
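
Note: the added branch re-creates the classic sinusoidal-embedding geometric frequency ladder. A standalone restatement with illustrative values (the real `min_timescale`, `max_timescale`, and `hidden_size` come from `TimesFmConfig`):

import math

import torch

min_timescale, max_timescale, hidden_size = 1, 10_000, 1280
num_timescales = hidden_size // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
inv_timescales = min_timescale * torch.exp(
    torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment
)
assert inv_timescales[0] == 1.0  # frequencies decay geometrically from min_timescale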
@@ -361,6 +373,7 @@ class TimesFmModel(TimesFmPreTrainedModel):
          freq: torch.Tensor,
          output_attentions: bool = False,
          output_hidden_states: bool = False,
+         **kwargs,
      ) -> TimesFmOutput:
          r"""
          past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -668,6 +681,7 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
          truncate_negative: bool = False,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
+         **kwargs,
      ) -> TimesFmOutputForPrediction:
          r"""
          past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -123,6 +123,7 @@ class TimesFmPositionalEmbedding(nn.Module):
          super().__init__()
          min_timescale = config.min_timescale
          max_timescale = config.max_timescale
+         self.min_timescale, self.max_timescale = min_timescale, max_timescale
          self.embedding_dims = config.hidden_size
 
          num_timescales = self.embedding_dims // 2
@@ -269,6 +270,17 @@ class TimesFmPreTrainedModel(PreTrainedModel):
          if isinstance(module, TimesFmAttention):
              # Initialize scaling parameter
              init.ones_(module.scaling)
+         elif isinstance(module, TimesFmPositionalEmbedding):
+             num_timescales = module.embedding_dims // 2
+             max_timescale, min_timescale = module.max_timescale, module.min_timescale
+             log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
+                 num_timescales - 1, 1
+             )
+             init.copy_(
+                 module.inv_timescales,
+                 min_timescale
+                 * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
+             )
 
 
  @auto_docstring
@@ -317,6 +329,7 @@ class TimesFmModel(TimesFmPreTrainedModel):
          freq: torch.Tensor,
          output_attentions: bool = False,
          output_hidden_states: bool = False,
+         **kwargs,
      ) -> TimesFmOutput:
          r"""
          past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -624,6 +637,7 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
          truncate_negative: bool = False,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
+         **kwargs,
      ) -> TimesFmOutputForPrediction:
          r"""
          past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -494,6 +494,7 @@ class TimesformerModel(TimesformerPreTrainedModel):
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
          r"""
          Examples:
@@ -624,6 +625,7 @@ class TimesformerForVideoClassification(TimesformerPreTrainedModel):
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         **kwargs,
      ) -> Union[tuple, ImageClassifierOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -16,10 +16,12 @@
  from typing import Optional, Union
 
  import torch
+ from torch import Tensor, nn
 
+ from ... import initialization as init
  from ...modeling_outputs import BackboneOutput
  from ...modeling_utils import PreTrainedModel
- from ...utils import is_timm_available, is_torch_available, requires_backends
+ from ...utils import is_timm_available, requires_backends
  from ...utils.backbone_utils import BackboneMixin
  from .configuration_timm_backbone import TimmBackboneConfig
 
@@ -28,10 +30,6 @@ if is_timm_available():
  import timm
 
 
- if is_torch_available():
-     from torch import Tensor
-
-
  class TimmBackbone(PreTrainedModel, BackboneMixin):
      """
      Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the
@@ -84,10 +82,11 @@ class TimmBackbone(PreTrainedModel, BackboneMixin):
          self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
          super()._init_backbone(config)
 
+         self.post_init()
+
      @classmethod
      def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
          requires_backends(cls, ["vision", "timm"])
-         from ...models.timm_backbone import TimmBackboneConfig
 
          config = kwargs.pop("config", TimmBackboneConfig())
 
@@ -116,9 +115,14 @@ class TimmBackbone(PreTrainedModel, BackboneMixin):
 
      @torch.no_grad()
      def _init_weights(self, module):
-         """
-         Empty init weights function to ensure compatibility of the class in the library.
-         """
+         """We need to at least re-init the non-persistent buffers if the model was initialized on meta device (we
+         assume weights and persistent buffers will be part of the checkpoint, as we have no way to control timm inits)."""
+         if hasattr(module, "init_non_persistent_buffers"):
+             module.init_non_persistent_buffers()
+         elif isinstance(module, nn.BatchNorm2d) and getattr(module, "running_mean", None) is not None:
+             init.zeros_(module.running_mean)
+             init.ones_(module.running_var)
+             init.zeros_(module.num_batches_tracked)
 
      def forward(
          self,
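
Note: on a materialized (non-meta) `nn.BatchNorm2d`, the buffer re-init above is roughly equivalent to PyTorch's own `reset_running_stats()`; the `init.*` helpers are used in the diff because the module may still live on the meta device:

import torch.nn as nn

bn = nn.BatchNorm2d(64)
bn.reset_running_stats()  # running_mean -> 0, running_var -> 1, num_batches_tracked -> 0
assert (bn.running_mean == 0).all() and (bn.running_var == 1).all()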