transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -421,40 +421,8 @@ class Swinv2SelfAttention(nn.Module):
421
421
  nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
422
422
  )
423
423
 
424
- # get relative_coords_table
425
- relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float()
426
- relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float()
427
- relative_coords_table = (
428
- torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
429
- .permute(1, 2, 0)
430
- .contiguous()
431
- .unsqueeze(0)
432
- ) # [1, 2*window_height - 1, 2*window_width - 1, 2]
433
- if pretrained_window_size[0] > 0:
434
- relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
435
- relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
436
- elif window_size > 1:
437
- relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
438
- relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
439
- relative_coords_table *= 8 # normalize to -8, 8
440
- relative_coords_table = (
441
- torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
442
- )
443
- # set to same dtype as mlp weight
444
- relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
424
+ relative_coords_table, relative_position_index = self.create_coords_table_and_index()
445
425
  self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
446
-
447
- # get pair-wise relative position index for each token inside the window
448
- coords_h = torch.arange(self.window_size[0])
449
- coords_w = torch.arange(self.window_size[1])
450
- coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
451
- coords_flatten = torch.flatten(coords, 1)
452
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
453
- relative_coords = relative_coords.permute(1, 2, 0).contiguous()
454
- relative_coords[:, :, 0] += self.window_size[0] - 1
455
- relative_coords[:, :, 1] += self.window_size[1] - 1
456
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
457
- relative_position_index = relative_coords.sum(-1)
458
426
  self.register_buffer("relative_position_index", relative_position_index, persistent=False)
459
427
 
460
428
  self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
@@ -530,6 +498,43 @@ class Swinv2SelfAttention(nn.Module):
530
498
 
531
499
  return outputs
532
500
 
501
+ def create_coords_table_and_index(self):
502
+ # get relative_coords_table
503
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float()
504
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float()
505
+ relative_coords_table = (
506
+ torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
507
+ .permute(1, 2, 0)
508
+ .contiguous()
509
+ .unsqueeze(0)
510
+ ) # [1, 2*window_height - 1, 2*window_width - 1, 2]
511
+ if self.pretrained_window_size[0] > 0:
512
+ relative_coords_table[:, :, :, 0] /= self.pretrained_window_size[0] - 1
513
+ relative_coords_table[:, :, :, 1] /= self.pretrained_window_size[1] - 1
514
+ elif self.window_size[0] > 1:
515
+ relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
516
+ relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
517
+ relative_coords_table *= 8 # normalize to -8, 8
518
+ relative_coords_table = (
519
+ torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
520
+ )
521
+ # set to same dtype as mlp weight
522
+ relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
523
+
524
+ # get pair-wise relative position index for each token inside the window
525
+ coords_h = torch.arange(self.window_size[0])
526
+ coords_w = torch.arange(self.window_size[1])
527
+ coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
528
+ coords_flatten = torch.flatten(coords, 1)
529
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
530
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
531
+ relative_coords[:, :, 0] += self.window_size[0] - 1
532
+ relative_coords[:, :, 1] += self.window_size[1] - 1
533
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
534
+ relative_position_index = relative_coords.sum(-1)
535
+
536
+ return relative_coords_table, relative_position_index
537
+
533
538
 
534
539
  # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->Swinv2
535
540
  class Swinv2SelfOutput(nn.Module):
@@ -904,6 +909,9 @@ class Swinv2PreTrainedModel(PreTrainedModel):
904
909
  init.zeros_(module.position_embeddings)
905
910
  elif isinstance(module, Swinv2SelfAttention):
906
911
  init.constant_(module.logit_scale, math.log(10))
912
+ relative_coords_table, relative_position_index = module.create_coords_table_and_index()
913
+ init.copy_(module.relative_coords_table, relative_coords_table)
914
+ init.copy_(module.relative_position_index, relative_position_index)
907
915
 
908
916
 
909
917
  @auto_docstring
@@ -111,7 +111,7 @@ class SwitchTransformersTop1Router(nn.Module):
111
111
  router_logits, expert_index = torch.max(router_probs, dim=-1, keepdim=True)
112
112
  expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)
113
113
  token_priority = torch.cumsum(expert_index, dim=-2)
114
- # mask if the token routed to to the expert will overflow
114
+ # mask if the token routed to the expert will overflow
115
115
  expert_capacity_mask = token_priority <= self.expert_capacity
116
116
  expert_index = expert_index * expert_capacity_mask
117
117
  router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
@@ -913,6 +913,7 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
913
913
  "encoder.embed_tokens.weight": "shared.weight",
914
914
  "decoder.embed_tokens.weight": "shared.weight",
915
915
  }
916
+ _input_embed_layer = "shared"
916
917
 
917
918
  def __init__(self, config: SwitchTransformersConfig):
918
919
  super().__init__(config)
@@ -921,20 +922,15 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
921
922
  encoder_config = copy.deepcopy(config)
922
923
  encoder_config.is_decoder = False
923
924
  encoder_config.use_cache = False
924
- encoder_config.tie_encoder_decoder = False
925
925
  self.encoder = SwitchTransformersStack(encoder_config)
926
926
 
927
927
  decoder_config = copy.deepcopy(config)
928
928
  decoder_config.is_decoder = True
929
- decoder_config.tie_encoder_decoder = False
930
929
  self.decoder = SwitchTransformersStack(decoder_config)
931
930
 
932
931
  # Initialize weights and apply final processing
933
932
  self.post_init()
934
933
 
935
- def get_input_embeddings(self):
936
- return self.shared
937
-
938
934
  def set_input_embeddings(self, new_embeddings):
939
935
  self.shared = new_embeddings
940
936
  self.encoder.set_input_embeddings(new_embeddings)
@@ -1072,12 +1068,10 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
1072
1068
  encoder_config = copy.deepcopy(config)
1073
1069
  encoder_config.is_decoder = False
1074
1070
  encoder_config.use_cache = False
1075
- encoder_config.tie_encoder_decoder = False
1076
1071
  self.encoder = SwitchTransformersStack(encoder_config)
1077
1072
 
1078
1073
  decoder_config = copy.deepcopy(config)
1079
1074
  decoder_config.is_decoder = True
1080
- decoder_config.tie_encoder_decoder = False
1081
1075
  decoder_config.num_layers = config.num_decoder_layers
1082
1076
  self.decoder = SwitchTransformersStack(decoder_config)
1083
1077
 
@@ -170,7 +170,7 @@ class SwitchTransformersTop1Router(nn.Module):
170
170
  router_logits, expert_index = torch.max(router_probs, dim=-1, keepdim=True)
171
171
  expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)
172
172
  token_priority = torch.cumsum(expert_index, dim=-2)
173
- # mask if the token routed to to the expert will overflow
173
+ # mask if the token routed to the expert will overflow
174
174
  expert_capacity_mask = token_priority <= self.expert_capacity
175
175
  expert_index = expert_index * expert_capacity_mask
176
176
  router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
@@ -669,6 +669,7 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
669
669
  "encoder.embed_tokens.weight": "shared.weight",
670
670
  "decoder.embed_tokens.weight": "shared.weight",
671
671
  }
672
+ _input_embed_layer = "shared"
672
673
 
673
674
  def __init__(self, config: SwitchTransformersConfig):
674
675
  super().__init__(config)
@@ -677,20 +678,15 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
677
678
  encoder_config = copy.deepcopy(config)
678
679
  encoder_config.is_decoder = False
679
680
  encoder_config.use_cache = False
680
- encoder_config.tie_encoder_decoder = False
681
681
  self.encoder = SwitchTransformersStack(encoder_config)
682
682
 
683
683
  decoder_config = copy.deepcopy(config)
684
684
  decoder_config.is_decoder = True
685
- decoder_config.tie_encoder_decoder = False
686
685
  self.decoder = SwitchTransformersStack(decoder_config)
687
686
 
688
687
  # Initialize weights and apply final processing
689
688
  self.post_init()
690
689
 
691
- def get_input_embeddings(self):
692
- return self.shared
693
-
694
690
  def set_input_embeddings(self, new_embeddings):
695
691
  self.shared = new_embeddings
696
692
  self.encoder.set_input_embeddings(new_embeddings)
@@ -763,12 +759,10 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
763
759
  encoder_config = copy.deepcopy(config)
764
760
  encoder_config.is_decoder = False
765
761
  encoder_config.use_cache = False
766
- encoder_config.tie_encoder_decoder = False
767
762
  self.encoder = SwitchTransformersStack(encoder_config)
768
763
 
769
764
  decoder_config = copy.deepcopy(config)
770
765
  decoder_config.is_decoder = True
771
- decoder_config.tie_encoder_decoder = False
772
766
  decoder_config.num_layers = config.num_decoder_layers
773
767
  self.decoder = SwitchTransformersStack(decoder_config)
774
768
 
@@ -131,13 +131,19 @@ class T5Config(PreTrainedConfig):
131
131
  if feed_forward_proj == "gated-gelu":
132
132
  self.dense_act_fn = "gelu_new"
133
133
 
134
+ # Super weird feature of T5 because we support T5 and T51.1 from the same
135
+ # model code. Original T5 always scaled outputs, but the 1.1v does not.
136
+ # The model code was relying on saved configs where `tie_word_embeddings` is
137
+ # set to `False` in 1.1v and using it as indicator of whether to scale or not
138
+ # But in fact we tie weights always and force it to be `True`
139
+ self.scale_decoder_outputs = kwargs.get("tie_word_embeddings") is not False
140
+ kwargs["tie_word_embeddings"] = True
134
141
  super().__init__(
135
142
  pad_token_id=pad_token_id,
136
143
  eos_token_id=eos_token_id,
137
144
  is_encoder_decoder=is_encoder_decoder,
138
145
  **kwargs,
139
146
  )
140
- self.tie_encoder_decoder = True # T5 is always tied, has always been like that.
141
147
 
142
148
 
143
149
  __all__ = ["T5Config"]
@@ -844,12 +844,10 @@ class T5Model(T5PreTrainedModel):
844
844
  encoder_config = copy.deepcopy(config)
845
845
  encoder_config.is_decoder = False
846
846
  encoder_config.use_cache = False
847
- encoder_config.tie_encoder_decoder = False
848
847
  self.encoder = T5Stack(encoder_config)
849
848
 
850
849
  decoder_config = copy.deepcopy(config)
851
850
  decoder_config.is_decoder = True
852
- decoder_config.tie_encoder_decoder = False
853
851
  decoder_config.num_layers = config.num_decoder_layers
854
852
  self.decoder = T5Stack(decoder_config)
855
853
 
@@ -1007,12 +1005,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin):
1007
1005
  encoder_config = copy.deepcopy(config)
1008
1006
  encoder_config.is_decoder = False
1009
1007
  encoder_config.use_cache = False
1010
- encoder_config.tie_encoder_decoder = False
1011
1008
  self.encoder = T5Stack(encoder_config)
1012
1009
 
1013
1010
  decoder_config = copy.deepcopy(config)
1014
1011
  decoder_config.is_decoder = True
1015
- decoder_config.tie_encoder_decoder = False
1016
1012
  decoder_config.num_layers = config.num_decoder_layers
1017
1013
  self.decoder = T5Stack(decoder_config)
1018
1014
 
@@ -1147,7 +1143,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin):
1147
1143
 
1148
1144
  sequence_output = decoder_outputs[0]
1149
1145
 
1150
- if self.config.tie_word_embeddings:
1146
+ if self.config.scale_decoder_outputs:
1151
1147
  sequence_output = sequence_output * (self.model_dim**-0.5)
1152
1148
 
1153
1149
  lm_logits = self.lm_head(sequence_output)
@@ -1487,12 +1483,10 @@ class T5ForQuestionAnswering(T5PreTrainedModel):
1487
1483
  encoder_config = copy.deepcopy(config)
1488
1484
  encoder_config.is_decoder = False
1489
1485
  encoder_config.use_cache = False
1490
- encoder_config.tie_encoder_decoder = False
1491
1486
  self.encoder = T5Stack(encoder_config)
1492
1487
 
1493
1488
  decoder_config = copy.deepcopy(config)
1494
1489
  decoder_config.is_decoder = True
1495
- decoder_config.tie_encoder_decoder = False
1496
1490
  decoder_config.num_layers = config.num_decoder_layers
1497
1491
  self.decoder = T5Stack(decoder_config)
1498
1492
 
@@ -108,7 +108,7 @@ class T5GemmaRotaryEmbedding(nn.Module):
108
108
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
109
109
 
110
110
  self.register_buffer("inv_freq", inv_freq, persistent=False)
111
- self.original_inv_freq = inv_freq
111
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
112
112
 
113
113
  @staticmethod
114
114
  def compute_default_rope_parameters(
@@ -32,9 +32,9 @@ logger = logging.get_logger(__name__)
32
32
 
33
33
  class T5Gemma2TextConfig(PreTrainedConfig):
34
34
  r"""
35
- This is the configuration class to store the configuration of a [`T5Gemma2TextModel`]. It is used to instantiate an T5Gemma2Text
36
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
37
- defaults will yield a similar configuration to that of the T5Gemma2Text-7B.
35
+ This is the configuration class to store the configuration of a [`T5Gemma2TextModel`]. It is used to instantiate the encoder's
36
+ text model portion of the T5Gemma2 Model according to the specified arguments, defining the model architecture. Instantiating
37
+ a configuration with the defaults will yield a similar configuration to that of the T5Gemma2Text-7B.
38
38
  e.g. [google/t5gemma2_text-7b](https://huggingface.co/google/t5gemma2_text-7b)
39
39
  Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
40
40
  documentation from [`PreTrainedConfig`] for more information.
@@ -99,19 +99,6 @@ class T5Gemma2TextConfig(PreTrainedConfig):
99
99
  Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
100
100
  a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
101
101
  with longer `max_position_embeddings`.
102
- use_bidirectional_attention (`bool`, *optional*, defaults to `False`):
103
- If True, the model will attend to all text tokens instead of using a causal mask. This does not change
104
- behavior for vision tokens.
105
-
106
- ```python
107
- >>> from transformers import T5Gemma2TextModel, T5Gemma2TextConfig
108
- >>> # Initializing a T5Gemma2Text t5gemma2_text-7b style configuration
109
- >>> configuration = T5Gemma2TextConfig()
110
- >>> # Initializing a model from the t5gemma2_text-7b style configuration
111
- >>> model = T5Gemma2TextModel(configuration)
112
- >>> # Accessing the model configuration
113
- >>> configuration = model.config
114
- ```
115
102
  """
116
103
 
117
104
  model_type = "t5gemma2_text"
@@ -158,7 +145,6 @@ class T5Gemma2TextConfig(PreTrainedConfig):
158
145
  final_logit_softcapping: Optional[float] = None,
159
146
  attn_logit_softcapping: Optional[float] = None,
160
147
  rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
161
- use_bidirectional_attention: Optional[bool] = False,
162
148
  **kwargs,
163
149
  ):
164
150
  self.vocab_size = vocab_size
@@ -181,10 +167,6 @@ class T5Gemma2TextConfig(PreTrainedConfig):
181
167
  self.attn_logit_softcapping = attn_logit_softcapping
182
168
  self.layer_types = layer_types
183
169
 
184
- self.use_bidirectional_attention = use_bidirectional_attention
185
- if use_bidirectional_attention:
186
- self.sliding_window = (self.sliding_window // 2) + 1 # due to fa we set exclusive bounds
187
-
188
170
  # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
189
171
  self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)
190
172
 
@@ -326,9 +308,9 @@ class T5Gemma2EncoderConfig(PreTrainedConfig):
326
308
 
327
309
  class T5Gemma2DecoderConfig(PreTrainedConfig):
328
310
  r"""
329
- This is the configuration class to store the configuration of a [`T5Gemma2DecoderModel`]. It is used to instantiate an T5Gemma2Decoder
330
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
331
- defaults will yield a similar configuration to that of the T5Gemma2Decoder-7B.
311
+ This is the configuration class to store the configuration of a [`T5Gemma2DecoderModel`]. It is used to instantiate the decoder
312
+ text model portion of the T5Gemma2 Model according to the specified arguments, defining the model architecture. Instantiating
313
+ a configuration with the defaults will yield a similar configuration to that of the T5Gemma2Decoder-7B.
332
314
  e.g. [google/t5gemma2_text-7b](https://huggingface.co/google/t5gemma2_text-7b)
333
315
  Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
334
316
  documentation from [`PreTrainedConfig`] for more information.
@@ -393,19 +375,6 @@ class T5Gemma2DecoderConfig(PreTrainedConfig):
393
375
  Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
394
376
  a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
395
377
  with longer `max_position_embeddings`.
396
- use_bidirectional_attention (`bool`, *optional*, defaults to `False`):
397
- If True, the model will attend to all text tokens instead of using a causal mask. This does not change
398
- behavior for vision tokens.
399
-
400
- ```python
401
- >>> from transformers import T5Gemma2DecoderModel, T5Gemma2DecoderConfig
402
- >>> # Initializing a T5Gemma2Decoder t5gemma2_text-7b style configuration
403
- >>> configuration = T5Gemma2DecoderConfig()
404
- >>> # Initializing a model from the t5gemma2_text-7b style configuration
405
- >>> model = T5Gemma2DecoderModel(configuration)
406
- >>> # Accessing the model configuration
407
- >>> configuration = model.config
408
- ```
409
378
  """
410
379
 
411
380
  model_type = "t5gemma2_decoder"
@@ -452,7 +421,6 @@ class T5Gemma2DecoderConfig(PreTrainedConfig):
452
421
  final_logit_softcapping: Optional[float] = None,
453
422
  attn_logit_softcapping: Optional[float] = None,
454
423
  rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
455
- use_bidirectional_attention: Optional[bool] = False,
456
424
  **kwargs,
457
425
  ):
458
426
  self.vocab_size = vocab_size
@@ -475,10 +443,6 @@ class T5Gemma2DecoderConfig(PreTrainedConfig):
475
443
  self.attn_logit_softcapping = attn_logit_softcapping
476
444
  self.layer_types = layer_types
477
445
 
478
- self.use_bidirectional_attention = use_bidirectional_attention
479
- if use_bidirectional_attention:
480
- self.sliding_window = (self.sliding_window // 2) + 1 # due to fa we set exclusive bounds
481
-
482
446
  # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
483
447
  self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)
484
448
 
@@ -113,7 +113,7 @@ class T5Gemma2RotaryEmbedding(nn.Module):
113
113
  rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
114
114
  curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
115
115
  self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
116
- setattr(self, f"{layer_type}_original_inv_freq", curr_inv_freq)
116
+ self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
117
117
  setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
118
118
 
119
119
  @staticmethod
@@ -266,7 +266,7 @@ class T5Gemma2SelfAttention(nn.Module):
266
266
  self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
267
267
  self.scaling = config.query_pre_attn_scalar**-0.5
268
268
  self.attention_dropout = self.config.attention_dropout
269
- self.is_causal = not self.config.use_bidirectional_attention
269
+ self.is_causal = False # Only used by the encoder
270
270
 
271
271
  self.q_proj = nn.Linear(
272
272
  config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
@@ -348,7 +348,7 @@ class T5Gemma2MergedAttention(nn.Module):
348
348
  self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
349
349
  self.scaling = config.query_pre_attn_scalar**-0.5
350
350
  self.attention_dropout = self.config.attention_dropout
351
- self.is_causal = not self.config.use_bidirectional_attention
351
+ self.is_causal = False # Fused causal and encoder mask
352
352
 
353
353
  self.q_proj = nn.Linear(
354
354
  config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
@@ -446,7 +446,6 @@ class T5Gemma2MergedAttention(nn.Module):
446
446
  merged_attention_mask,
447
447
  dropout=self.attention_dropout if self.training else 0.0,
448
448
  scaling=self.scaling,
449
- is_causal=False,
450
449
  **kwargs,
451
450
  )
452
451
 
@@ -649,6 +648,7 @@ class T5Gemma2TextScaledWordEmbedding(nn.Embedding):
649
648
  eoi_token_index: int = 256_000,
650
649
  ):
651
650
  super().__init__(num_embeddings, embedding_dim, padding_idx)
651
+ self.scalar_embed_scale = embed_scale
652
652
  self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
653
653
  self.eoi_token_index = eoi_token_index
654
654
  self.eoi_embedding = nn.Parameter(torch.zeros(self.embedding_dim))
@@ -700,6 +700,7 @@ class T5Gemma2PreTrainedModel(PreTrainedModel):
700
700
  init.zeros_(module.mm_input_projection_weight)
701
701
  elif isinstance(module, T5Gemma2TextScaledWordEmbedding):
702
702
  init.zeros_(module.eoi_embedding)
703
+ init.constant_(module.embed_scale, module.scalar_embed_scale)
703
704
  elif isinstance(module, T5Gemma2ClassificationHead):
704
705
  scale = module.out_proj.weight.shape[0] ** -0.5
705
706
  init.normal_(module.out_proj.weight, mean=0.0, std=self.config.initializer_range * scale)
@@ -708,6 +709,14 @@ class T5Gemma2PreTrainedModel(PreTrainedModel):
708
709
  # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
709
710
  elif "RMSNorm" in module.__class__.__name__:
710
711
  init.zeros_(module.weight)
712
+ elif isinstance(module, T5Gemma2RotaryEmbedding):
713
+ for layer_type in module.layer_types:
714
+ rope_init_fn = module.compute_default_rope_parameters
715
+ if module.rope_type[layer_type] != "default":
716
+ rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
717
+ curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
718
+ init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
719
+ init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
711
720
 
712
721
  def prepare_decoder_input_ids_from_labels(self, input_ids):
713
722
  """