transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/fuyu/processing_fuyu.py

```diff
@@ -347,6 +347,22 @@ class FuyuProcessor(ProcessorMixin):
         The tokenizer is a required input.
     """
 
+    @classmethod
+    def _load_tokenizer_from_pretrained(
+        cls, sub_processor_type, pretrained_model_name_or_path, subfolder="", **kwargs
+    ):
+        """
+        Override for BC. Fuyu uses TokenizersBackend and requires token_type_ids to be removed from model_input_names
+        because Fuyu uses mm_token_type_ids instead for multimodal token identification.
+        """
+        from ...tokenization_utils_tokenizers import TokenizersBackend
+
+        tokenizer = TokenizersBackend.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        # Remove token_type_ids as Fuyu uses mm_token_type_ids instead
+        if "token_type_ids" in tokenizer.model_input_names:
+            tokenizer.model_input_names.remove("token_type_ids")
+        return tokenizer
+
     def __init__(self, image_processor, tokenizer, **kwargs):
         super().__init__(image_processor=image_processor, tokenizer=tokenizer)
         self.image_processor = image_processor
```
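The override only changes which input names the loaded tokenizer advertises. A minimal sketch of the observable effect, assuming a Fuyu checkpoint is available (the checkpoint name below is illustrative):

```python
# Sketch: after loading, "token_type_ids" is gone from the tokenizer's
# model_input_names; the processor emits "mm_token_type_ids" instead.
from transformers import FuyuProcessor

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")  # illustrative checkpoint
assert "token_type_ids" not in processor.tokenizer.model_input_names
```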
transformers/models/gemma/modeling_gemma.py

```diff
@@ -98,7 +98,7 @@ class GemmaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
```
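Registering `original_inv_freq` as a non-persistent buffer (instead of a plain attribute) means it follows the module across `.to()` device/dtype moves while staying out of the state dict, and the `.clone()` breaks aliasing with `inv_freq`, which dynamic-RoPE variants may rescale in place. A minimal standalone sketch of those two properties (not the library class):

```python
# Sketch of the buffer semantics behind the change above.
import torch
from torch import nn

class Rope(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = torch.arange(1, 5, dtype=torch.float32)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

m = Rope()
m.inv_freq.mul_(2.0)                                  # e.g. a dynamic-RoPE rescale in place
print(torch.equal(m.inv_freq, m.original_inv_freq))   # False: .clone() broke the aliasing
print("original_inv_freq" in m.state_dict())          # False: persistent=False keeps it out
m.to(torch.float64)                                   # both buffers follow the module move
```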
```diff
@@ -410,16 +410,14 @@ class GemmaModel(GemmaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        # It may already have been prepared by e.g. `generate`
-        if not isinstance(causal_mask_mapping := attention_mask, dict):
-            causal_mask_mapping = create_causal_mask(
-                config=self.config,
-                input_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                cache_position=cache_position,
-                past_key_values=past_key_values,
-                position_ids=position_ids,
-            )
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -434,7 +432,7 @@ class GemmaModel(GemmaPreTrainedModel):
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping,
+                attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
```
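rc2 drops the fast path that accepted a pre-built mask mapping from `generate` and always rebuilds the mask locally; the rename from `causal_mask_mapping` to `causal_mask` follows. A sketch of the contract change (helper names here are illustrative, not the library API):

```python
# Sketch of the removed pattern vs. the rc2 pattern (illustrative names).
def resolve_mask_rc1(attention_mask, build_mask):
    # Old: trust a prebuilt dict of masks if the caller supplied one.
    if not isinstance(mask := attention_mask, dict):
        mask = build_mask(attention_mask)
    return mask

def resolve_mask_rc2(attention_mask, build_mask):
    # New: always rebuild, regardless of what the caller passed.
    return build_mask(attention_mask)
```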
transformers/models/gemma/modular_gemma.py

```diff
@@ -267,16 +267,14 @@ class GemmaModel(LlamaModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        # It may already have been prepared by e.g. `generate`
-        if not isinstance(causal_mask_mapping := attention_mask, dict):
-            causal_mask_mapping = create_causal_mask(
-                config=self.config,
-                input_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                cache_position=cache_position,
-                past_key_values=past_key_values,
-                position_ids=position_ids,
-            )
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -291,7 +289,7 @@ class GemmaModel(LlamaModel):
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping,
+                attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
```
transformers/models/gemma2/modeling_gemma2.py

```diff
@@ -99,7 +99,7 @@ class Gemma2RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
```
transformers/models/gemma2/modular_gemma2.py

```diff
@@ -244,7 +244,7 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @torch.no_grad()
     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
```
transformers/models/gemma3/image_processing_gemma3_fast.py

```diff
@@ -231,7 +231,6 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
         return BatchFeature(
             data={"pixel_values": processed_images, "num_crops": num_crops}, tensor_type=return_tensors
         )
```
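The explicit `torch.stack` is removed on the premise that `BatchFeature`'s tensor conversion now batches a list of equal-shaped image tensors itself (note that `feature_extraction_utils.py` also changed in this release). A sketch of the shape contract that conversion must satisfy:

```python
# Sketch of the expected result: N crops of shape (C, H, W) become one
# (N, C, H, W) tensor once tensor_type="pt" conversion is applied.
import torch

crops = [torch.zeros(3, 224, 224), torch.ones(3, 224, 224)]
batched = torch.stack(crops, dim=0)  # what the conversion is assumed to yield
print(batched.shape)  # torch.Size([2, 3, 224, 224])
```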
transformers/models/gemma3/modeling_gemma3.py

```diff
@@ -100,6 +100,7 @@ class Gemma3TextScaledWordEmbedding(nn.Embedding):
 
     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.scalar_embed_scale = embed_scale
         self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
 
     def forward(self, input_ids: torch.Tensor):
@@ -165,7 +166,7 @@ class Gemma3RotaryEmbedding(nn.Module):
             rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
             curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
             self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
-            setattr(self, f"{layer_type}_original_inv_freq", curr_inv_freq)
+            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
             setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
 
     @staticmethod
@@ -468,6 +469,16 @@ class Gemma3PreTrainedModel(PreTrainedModel):
         # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
         elif "RMSNorm" in module.__class__.__name__:
             init.zeros_(module.weight)
+        elif isinstance(module, Gemma3TextScaledWordEmbedding):
+            init.constant_(module.embed_scale, module.scalar_embed_scale)
+        elif isinstance(module, Gemma3RotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
 
 
 def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
```
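These `_init_weights` branches exist because non-persistent buffers are not stored in checkpoints: when the model is materialized (e.g. from the meta device), something has to recompute `embed_scale` and the per-layer-type `inv_freq` buffers. That is also why `__init__` now keeps the plain-float `scalar_embed_scale` around. A minimal standalone sketch (simplified, not the library class) of the restore pattern:

```python
# Sketch: keep the plain Python value next to a non-persistent buffer so
# init code can restore the buffer after it is (re)materialized.
import torch
from torch import nn
from torch.nn import init

class ScaledEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, embed_scale=1.0):
        super().__init__(num_embeddings, embedding_dim)
        self.scalar_embed_scale = embed_scale  # plain float, survives any reload
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

emb = ScaledEmbedding(10, 4, embed_scale=2.0)
emb.embed_scale.zero_()  # stand-in for an uninitialized buffer after materialization
init.constant_(emb.embed_scale, emb.scalar_embed_scale)
print(emb.embed_scale)  # tensor(2.)
```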
```diff
@@ -754,6 +765,7 @@ def create_causal_mask_mapping(
     token_type_ids: Optional[torch.Tensor] = None,
     pixel_values: Optional[torch.FloatTensor] = None,
     is_training: bool = False,
+    is_first_iteration: Optional[bool] = None,
     **kwargs,
 ) -> dict:
     """
@@ -776,8 +788,12 @@
     # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
     # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
     # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
-    may_have_image_input = past_key_values is None or not past_key_values.is_initialized or pixel_values is not None
-    if token_type_ids is not None and may_have_image_input:
+    is_first_iteration = (
+        is_first_iteration
+        if is_first_iteration is not None
+        else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
+    )
+    if token_type_ids is not None and is_first_iteration:
         # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
         # undo the causal masking)
 
@@ -1123,6 +1139,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
         use_cache=True,
         logits_to_keep=None,
         labels=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- custom `position_ids` and `pixel_values` handling
@@ -1136,12 +1153,15 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
             use_cache=use_cache,
             logits_to_keep=logits_to_keep,
             token_type_ids=token_type_ids,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        # If we're in the cached decoding stage, pixel values should be None because input ids do not contain the special image token anymore
-        # Otherwise we need pixel values to be passed to the model. NOTE: use_cache=False needs pixel_values always
-        if cache_position[0] == 0:
+        # Pixel values are used only in the first iteration, if available.
+        # In subsequent iterations, they are already merged with text and cached.
+        # NOTE: the first iteration doesn't have to be prefill; it can be the first
+        # iteration with a question and a cached system prompt (continue generating from cache). NOTE: use_cache=False needs pixel_values always.
+        if is_first_iteration or not use_cache:
             model_inputs["pixel_values"] = pixel_values
 
         return model_inputs
```
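The old `cache_position[0] == 0` test equated "first generation step" with "prefill from an empty cache". When generation continues from a pre-filled cache (e.g. a cached system prompt), the first step starts at a nonzero cache position, so `pixel_values` were silently dropped; the explicit `is_first_iteration` flag (assumed here to be threaded through by the generation loop) fixes that. A sketch of the two predicates on such a step:

```python
# Sketch: first step of generation continuing from a cached system prompt.
import torch

cache_position = torch.tensor([128])  # cache already holds 128 prompt tokens
is_first_iteration, use_cache = True, True

old_keep_pixels = bool(cache_position[0] == 0)         # False -> pixels dropped (bug)
new_keep_pixels = is_first_iteration or not use_cache  # True  -> pixels kept
print(old_keep_pixels, new_keep_pixels)
```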
```diff
@@ -1155,6 +1175,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
         past_key_values: Optional[Cache],
         position_ids: Optional[torch.Tensor],
         token_type_ids: Optional[torch.Tensor] = None,
+        is_first_iteration: Optional[bool] = False,
         **kwargs,
     ) -> dict:
         # Uses the overwritten `create_masks_for_generate` with `token_type_ids` masking
@@ -1166,7 +1187,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
             past_key_values,
             position_ids,
             token_type_ids,
-            pixel_values=kwargs.get("pixel_values"),
+            is_first_iteration=is_first_iteration,
             **{k: v for k, v in kwargs.items() if k != "pixel_values"},
         )
 
```
@@ -352,6 +352,7 @@ class Gemma3TextScaledWordEmbedding(nn.Embedding):
352
352
 
353
353
  def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
354
354
  super().__init__(num_embeddings, embedding_dim, padding_idx)
355
+ self.scalar_embed_scale = embed_scale
355
356
  self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
356
357
 
357
358
  def forward(self, input_ids: torch.Tensor):
@@ -389,7 +390,7 @@ class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding):
389
390
  rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
390
391
  curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
391
392
  self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
392
- setattr(self, f"{layer_type}_original_inv_freq", curr_inv_freq)
393
+ self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
393
394
  setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
394
395
 
395
396
  @staticmethod
@@ -576,6 +577,16 @@ class Gemma3PreTrainedModel(Gemma2PreTrainedModel):
576
577
  # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
577
578
  elif "RMSNorm" in module.__class__.__name__:
578
579
  init.zeros_(module.weight)
580
+ elif isinstance(module, Gemma3TextScaledWordEmbedding):
581
+ init.constant_(module.embed_scale, module.scalar_embed_scale)
582
+ elif isinstance(module, Gemma3RotaryEmbedding):
583
+ for layer_type in module.layer_types:
584
+ rope_init_fn = module.compute_default_rope_parameters
585
+ if module.rope_type[layer_type] != "default":
586
+ rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
587
+ curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
588
+ init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
589
+ init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
579
590
 
580
591
 
581
592
  def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
@@ -734,6 +745,7 @@ def create_causal_mask_mapping(
     token_type_ids: Optional[torch.Tensor] = None,
     pixel_values: Optional[torch.FloatTensor] = None,
     is_training: bool = False,
+    is_first_iteration: Optional[bool] = None,
     **kwargs,
 ) -> dict:
     """
@@ -756,8 +768,12 @@ def create_causal_mask_mapping(
     # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
     # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
     # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
-    may_have_image_input = past_key_values is None or not past_key_values.is_initialized or pixel_values is not None
-    if token_type_ids is not None and may_have_image_input:
+    is_first_iteration = (
+        is_first_iteration
+        if is_first_iteration is not None
+        else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
+    )
+    if token_type_ids is not None and is_first_iteration:
         # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
         # undo the causal masking)
 
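The ternary keeps backward compatibility: callers that pass the new flag are trusted, everyone else falls back to the old `may_have_image_input` heuristic. Written as a standalone helper (hypothetical name, same expression as above):

    # Hypothetical helper equivalent to the inline expression in the hunk above.
    def resolve_is_first_iteration(is_first_iteration, past_key_values, pixel_values) -> bool:
        if is_first_iteration is not None:
            return is_first_iteration  # explicit flag from the caller wins
        # Legacy heuristic: no cache yet, an uninitialized cache, or images present.
        return (
            past_key_values is None
            or not past_key_values.is_initialized
            or pixel_values is not None
        )
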
@@ -1005,6 +1021,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
         use_cache=True,
         logits_to_keep=None,
         labels=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- custom `position_ids` and `pixel_values` handling
@@ -1018,12 +1035,15 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
             use_cache=use_cache,
             logits_to_keep=logits_to_keep,
             token_type_ids=token_type_ids,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
-        # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
-        if cache_position[0] == 0:
+        # Pixel values are used only in the first iteration, if available.
+        # In subsequent iterations they are already merged with the text embeddings and cached.
+        # NOTE: the first iteration doesn't have to be prefill; it can be the first iteration with a
+        # question on top of a cached system prompt (continuing generation from the cache). NOTE: use_cache=False always needs pixel_values.
+        if is_first_iteration or not use_cache:
             model_inputs["pixel_values"] = pixel_values
 
         return model_inputs
@@ -495,6 +495,9 @@ class Gemma3nVisionConfig(PreTrainedConfig):
 
     @classmethod
     def from_dict(cls, config_dict: dict[str, Any], **kwargs):
+        # Create a copy to avoid mutating the original dict
+        config_dict = config_dict.copy()
+
         label_names = config_dict.get("label_names")
         is_custom_model = "num_labels" in kwargs or "id2label" in kwargs
 
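A small illustration of the aliasing the added `.copy()` guards against (toy dict and function, not the real config class): if `from_dict` normalizes keys in place, a caller's dict — possibly reused to build several sub-configs — would silently lose entries without the shallow copy.

    def from_dict(config_dict: dict) -> dict:
        config_dict = config_dict.copy()      # the fix: edit a private copy
        config_dict.pop("label_names", None)  # in-place normalization is now safe
        return config_dict

    caller_dict = {"label_names": ["cat", "dog"], "hidden_size": 32}
    from_dict(caller_dict)
    assert "label_names" in caller_dict  # caller's dict is untouched
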
@@ -329,6 +329,16 @@ class Gemma3nAudioAttention(nn.Module):
         r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
         self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)
 
+        local_causal_valid_mask = self.create_local_causal_valid_mask()
+        self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
+
+        self.register_buffer(
+            "softcap",
+            torch.tensor(self.attention_logits_soft_cap).float(),
+            persistent=False,
+        )
+
+    def create_local_causal_valid_mask(self):
         lower_causal_mask = torch.tril(
             torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
             diagonal=0,
@@ -339,13 +349,7 @@ class Gemma3nAudioAttention(nn.Module):
         )
         local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
         local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
-        self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
-
-        self.register_buffer(
-            "softcap",
-            torch.tensor(self.attention_logits_soft_cap).float(),
-            persistent=False,
-        )
+        return local_causal_valid_mask
 
     def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
         batch, _, *tail_shape = x.shape
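
Factoring the mask into `create_local_causal_valid_mask` lets `_init_weights` rebuild the buffer deterministically (see the `Gemma3nAudioAttention` branch later in this diff). A standalone sketch of what a mask like this computes, using index arithmetic instead of tril/triu products — a purely causal variant with illustrative sizes; the real mask above also multiplies in an `upper_causal_mask`:

    import torch

    # Each of the chunk_size queries sees a window of context_size keys (left
    # context plus the chunk itself) and must not attend past its own position.
    def local_causal_valid_mask(chunk_size: int, context_size: int) -> torch.Tensor:
        left_context = context_size - chunk_size
        q = torch.arange(chunk_size).unsqueeze(1)    # query index inside the chunk
        k = torch.arange(context_size).unsqueeze(0)  # key index inside the window
        return k <= q + left_context                 # True where attention is valid

    print(local_causal_valid_mask(3, 5).int())
    # tensor([[1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 0],
    #         [1, 1, 1, 1, 1]], dtype=torch.int32)
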
@@ -919,6 +923,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
         self.conformer = nn.ModuleList(
             [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
         )
+        self.post_init()
 
     def forward(
         self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, **kwargs
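
The added `self.post_init()` lets `Gemma3nAudioEncoder`, a standalone `PreTrainedModel`, run the library's weight-initialization hook once all submodules exist. A plain-PyTorch toy (not the transformers API) of why the call must come last in `__init__`:

    import torch
    from torch import nn

    class Encoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.a = nn.Linear(2, 2)
            self._init_all()          # too early: self.b does not exist yet
            self.b = nn.Linear(2, 2)  # escapes the init pass entirely

        @torch.no_grad()
        def _init_all(self):
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    m.weight.zero_()

    enc = Encoder()
    print(enc.a.weight.abs().sum().item())  # 0.0 -> was initialized
    print(enc.b.weight.abs().sum().item())  # > 0 -> missed (still random init)
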
@@ -983,6 +988,7 @@ class Gemma3nTextScaledWordEmbedding(nn.Embedding):
 
     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.scalar_embed_scale = embed_scale
         self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
 
     def forward(self, input_ids: torch.Tensor):
@@ -1449,8 +1455,38 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, Gemma3nAudioAttention):
             init.zeros_(module.per_dim_scale)
+            q_scale = module.head_dim**-0.5
+            r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
+            init.copy_(module.q_scale, q_scale * r_softplus_0)
+            init.constant_(module.softcap, module.attention_logits_soft_cap)
+            init.copy_(module.local_causal_valid_mask, module.create_local_causal_valid_mask())
+        elif isinstance(module, Gemma3nTextScaledWordEmbedding):
+            init.constant_(module.embed_scale, module.scalar_embed_scale)
         elif isinstance(module, Gemma3nTextAltUp):
             init.zeros_(module.correct_output_scale)
+            init.constant_(module.router_input_scale, self.config.hidden_size**-1.0)
+        elif isinstance(module, Gemma3nAudioRelativePositionEmbedding):
+            min_timescale, max_timescale = 1.0, 1.0e4
+            num_timescales = module.channels // 2
+            log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
+                num_timescales - 1, 1
+            )
+            inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
+            init.copy_(module.inv_timescales, inv_timescales.float().unsqueeze(0).unsqueeze(0))
+        elif isinstance(module, Gemma3nTextModel):
+            init.constant_(module.per_layer_projection_scale, self.hidden_size**-0.5)
+            init.constant_(module.per_layer_input_scale, 1 / math.sqrt(2.0))
+        elif isinstance(module, Gemma3nRotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
+
+        if hasattr(module, "gradient_clipping"):
+            init.constant_(module.gradient_clipping, self.config.gradient_clipping)
 
 
 class Gemma3nRotaryEmbedding(nn.Module):
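
The `Gemma3nAudioRelativePositionEmbedding` branch above re-creates the standard sinusoidal timescale table, geometrically spaced between 1.0 and 1.0e4. Recomputed standalone with the same formula (illustrative channel count):

    import math
    import torch

    def inv_timescales(channels: int) -> torch.Tensor:
        min_timescale, max_timescale = 1.0, 1.0e4
        num_timescales = channels // 2
        log_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
        return min_timescale * torch.exp(torch.arange(num_timescales) * -log_increment)

    print(inv_timescales(8))
    # tensor([1.0000e+00, 4.6416e-02, 2.1544e-03, 1.0000e-04])
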
@@ -1476,7 +1512,7 @@ class Gemma3nRotaryEmbedding(nn.Module):
             rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
             curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
             self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
-            setattr(self, f"{layer_type}_original_inv_freq", curr_inv_freq)
+            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
             setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
 
     @staticmethod
@@ -2301,6 +2337,7 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
         use_cache=True,
         logits_to_keep=None,
         labels=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- custom `position_ids` and `pixel_values` handling
@@ -2314,13 +2351,14 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
             use_cache=use_cache,
             logits_to_keep=logits_to_keep,
             token_type_ids=token_type_ids,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
         # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
         # tokens anymore. Otherwise multimodal inputs should be passed to model.
         # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
-        if cache_position[0] == 0:
+        if is_first_iteration or not use_cache:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["input_features"] = input_features
             model_inputs["input_features_mask"] = input_features_mask
@@ -27,7 +27,7 @@ from ...cache_utils import Cache, DynamicCache
 from ...configuration_utils import PreTrainedConfig, layer_type_validation
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_outputs import BaseModelOutputWithPast
-from ...modeling_rope_utils import RopeParameters
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
@@ -45,6 +45,7 @@ from ..gemma3.modeling_gemma3 import (
     Gemma3DecoderLayer,
     Gemma3ForCausalLM,
     Gemma3RMSNorm,
+    Gemma3RotaryEmbedding,
     Gemma3TextModel,
     Gemma3TextScaledWordEmbedding,
 )
@@ -882,6 +883,16 @@ class Gemma3nAudioAttention(nn.Module):
         r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
         self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)
 
+        local_causal_valid_mask = self.create_local_causal_valid_mask()
+        self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
+
+        self.register_buffer(
+            "softcap",
+            torch.tensor(self.attention_logits_soft_cap).float(),
+            persistent=False,
+        )
+
+    def create_local_causal_valid_mask(self):
         lower_causal_mask = torch.tril(
             torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
             diagonal=0,
@@ -892,13 +903,7 @@ class Gemma3nAudioAttention(nn.Module):
         )
         local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
         local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
-        self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
-
-        self.register_buffer(
-            "softcap",
-            torch.tensor(self.attention_logits_soft_cap).float(),
-            persistent=False,
-        )
+        return local_causal_valid_mask
 
     def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
         batch, _, *tail_shape = x.shape
@@ -1472,6 +1477,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
         self.conformer = nn.ModuleList(
             [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
         )
+        self.post_init()
 
     def forward(
         self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, **kwargs
@@ -1892,8 +1898,42 @@ class Gemma3nPreTrainedModel(Gemma2PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, Gemma3nAudioAttention):
             init.zeros_(module.per_dim_scale)
+            q_scale = module.head_dim**-0.5
+            r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
+            init.copy_(module.q_scale, q_scale * r_softplus_0)
+            init.constant_(module.softcap, module.attention_logits_soft_cap)
+            init.copy_(module.local_causal_valid_mask, module.create_local_causal_valid_mask())
+        elif isinstance(module, Gemma3nTextScaledWordEmbedding):
+            init.constant_(module.embed_scale, module.scalar_embed_scale)
         elif isinstance(module, Gemma3nTextAltUp):
             init.zeros_(module.correct_output_scale)
+            init.constant_(module.router_input_scale, self.config.hidden_size**-1.0)
+        elif isinstance(module, Gemma3nAudioRelativePositionEmbedding):
+            min_timescale, max_timescale = 1.0, 1.0e4
+            num_timescales = module.channels // 2
+            log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
+                num_timescales - 1, 1
+            )
+            inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
+            init.copy_(module.inv_timescales, inv_timescales.float().unsqueeze(0).unsqueeze(0))
+        elif isinstance(module, Gemma3nTextModel):
+            init.constant_(module.per_layer_projection_scale, self.hidden_size**-0.5)
+            init.constant_(module.per_layer_input_scale, 1 / math.sqrt(2.0))
+        elif isinstance(module, Gemma3nRotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
+
+        if hasattr(module, "gradient_clipping"):
+            init.constant_(module.gradient_clipping, self.config.gradient_clipping)
+
+
+class Gemma3nRotaryEmbedding(Gemma3RotaryEmbedding):
+    pass
 
 
 @auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.")
@@ -2543,6 +2583,7 @@ class Gemma3nForConditionalGeneration(PaliGemmaForConditionalGeneration):
         use_cache=True,
         logits_to_keep=None,
         labels=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- custom `position_ids` and `pixel_values` handling
@@ -2556,13 +2597,14 @@ class Gemma3nForConditionalGeneration(PaliGemmaForConditionalGeneration):
             use_cache=use_cache,
             logits_to_keep=logits_to_keep,
             token_type_ids=token_type_ids,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
         # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
         # tokens anymore. Otherwise multimodal inputs should be passed to model.
         # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
-        if cache_position[0] == 0:
+        if is_first_iteration or not use_cache:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["input_features"] = input_features
             model_inputs["input_features_mask"] = input_features_mask